[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332521299
Change-Id: I29af455bcb48d0b92e58132d951a3badbd772d56
This commit is contained in:
Mihai Maruseac 2020-09-18 13:56:43 -07:00 committed by TensorFlower Gardener
parent fff2c83262
commit 1970c2158b
83 changed files with 2720 additions and 1203 deletions

View File

@ -252,8 +252,10 @@ void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
return context->ResizeTensor(context, output,
@ -272,8 +274,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
@ -300,12 +304,14 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
HardSwishParams* params = &data->params;
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
params->input_zero_point = input->params.zero_point;
params->output_zero_point = output->params.zero_point;
const float input_scale = input->params.scale;
@ -337,8 +343,10 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
@ -366,8 +374,10 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (kernel_type == kFixedPointOptimized) {
@ -451,8 +461,10 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (kernel_type == kFixedPointOptimized) {
@ -546,8 +558,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
input->type == kTfLiteUInt8 ||
@ -614,8 +628,10 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
@ -650,9 +666,12 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
@ -704,8 +723,10 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@ -732,8 +753,10 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@ -763,8 +786,10 @@ template <KernelType kernel_type>
TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@ -814,8 +839,10 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@ -845,8 +872,10 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@ -919,8 +948,10 @@ template <KernelType kernel_type>
TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@ -1067,8 +1098,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
@ -1122,8 +1155,10 @@ template <KernelType kernel_type>
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
const LogSoftmaxOpData* data =
reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
SoftmaxParams op_params;
@ -1183,9 +1218,12 @@ T ApplyPrelu(T input, T alpha) {
template <KernelType kernel_type>
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@ -1294,8 +1332,10 @@ void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
}
TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const auto* params =
reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
const LeakyReluOpData* data =
@ -1332,8 +1372,10 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Use LUT to handle quantized elu path.
@ -1346,8 +1388,10 @@ TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),

View File

@ -91,9 +91,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -358,9 +364,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);

View File

@ -33,13 +33,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, num_inputs >= 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input1->type;
// Check that all input tensors have the same shape and type.
for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
}
@ -55,15 +60,22 @@ template <typename T>
void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
// TODO(haoliang): Initialize all_inputs only once during init.
VectorOfTensors<T> all_inputs(*context, *node->inputs);
// Safe to use unchecked since caller checks that tensor is valid
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int num_inputs = NumInputs(node);
// Safe to use unchecked since caller checks that tensor is valid
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
GetTensorData<T>(output));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32) {
EvalAddN<float>(context, node);
} else if (output->type == kTfLiteInt32) {

View File

@ -58,15 +58,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
// Make sure the axis is only 1 dimension.
TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
// Make sure the axis is only either int32 or int64.
TF_LITE_ENSURE(context,
axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
switch (params->output_type) {
@ -119,9 +123,13 @@ std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
}

View File

@ -40,8 +40,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// everything still works fine when variable ops aren't used.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
@ -51,9 +52,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
const TfLiteTensor* input_value_tensor;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputValue, &input_value_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();

View File

@ -76,8 +76,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
@ -106,8 +109,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
params->stride));

View File

@ -60,13 +60,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* hidden_state =
GetInput(context, node, kHiddenStateTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
const TfLiteTensor* hidden_state;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@ -86,7 +93,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
@ -105,7 +114,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -114,8 +125,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&hidden_state_quantized));
hidden_state_quantized->type = input_weights->type;
hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
@ -127,7 +139,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
hidden_state_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -138,7 +152,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -151,7 +167,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -162,7 +180,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
@ -260,14 +280,23 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
TfLiteTensor* hidden_state =
&context->tensors[node->inputs->data[kHiddenStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
GetVariableInput(context, node, kHiddenStateTensor);
TF_LITE_ENSURE(context, hidden_state != nullptr);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// We already checked that weight types are consistent, so branch on one.
switch (input_weights->type) {
@ -277,12 +306,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
// TODO(mirkov): implement eval with quantized inputs as well.
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &input_quantized));
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &accum_scratch));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &zero_points));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points,

View File

@ -154,7 +154,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
// Temp tensor for Transposed LHS;
{
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank);
for (int i = 0; i < lhs_rank - 2; ++i) {
scratch_buffer_size->data[i] = lhs->dims->data[i];
@ -175,7 +177,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
// is set by the caller, the data is already in the desired layout.
{
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer));
const TfLiteTensor* rhs = op_context->rhs;
int rhs_rank = NumDimensions(rhs);
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank);
@ -215,7 +219,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
}
op_data->compute_row_sums = true;
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&input_quantized));
input_quantized->type = op_context->rhs->type;
input_quantized->allocation_type = kTfLiteArenaRw;
@ -225,7 +231,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
input_quantized_size));
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
// Total size of scaling factors is batch size * number of total batches
@ -238,7 +246,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -252,7 +262,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@ -262,7 +274,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
input_offsets_size));
}
node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/6, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_weights_matrices * num_units};
@ -288,9 +302,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bool adj_x = op_context.params->adj_x;
bool adj_y = op_context.params->adj_y;
const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* lhs_data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputLHSTensor, &lhs_data));
const TfLiteTensor* rhs_data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputRHSTensor, &rhs_data));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@ -502,11 +522,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const RuntimeShape& rhs_shape,
const TfLiteTensor* rhs, TfLiteTensor* output) {
if (lhs->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&input_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/6, &row_sums));
return EvalHybrid<kernel_type>(
context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
scaling_factors, accum_scratch, row_sums, input_offsets, output);
@ -524,6 +554,10 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* rhs) {
TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
if (transposed_rhs == nullptr) {
return nullptr;
}
if (rhs->type == kTfLiteInt8) {
// Get the quantization params from the RHS tensor.
transposed_rhs->params.scale = rhs->params.scale;
@ -535,6 +569,10 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lhs) {
TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
if (transposed_lhs == nullptr) {
return nullptr;
}
if (lhs->type == kTfLiteInt8) {
// Get the quantization params from the LHS tensor.
transposed_lhs->params.scale = lhs->params.scale;
@ -558,9 +596,15 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* lhs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputLHSTensor, &lhs));
const TfLiteTensor* rhs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputRHSTensor, &rhs));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
RuntimeShape orig_lhs_shape = GetTensorShape(lhs);
RuntimeShape orig_rhs_shape = GetTensorShape(rhs);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -192,8 +193,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE(context, params->cell_clip >= 0);
TF_LITE_ENSURE(context, params->proj_clip >= 0);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, input_to_forget_weights_tensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_forget_weights_tensor,
&input_to_forget_weights));
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
@ -211,16 +214,20 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
input_to_forget_weights->type);
}
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, input_to_cell_weights_tensor);
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_cell_weights_tensor,
&input_to_cell_weights));
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, input_to_output_weights_tensor);
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_output_weights_tensor,
&input_to_output_weights));
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
@ -239,8 +246,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
input_to_forget_weights->type);
}
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, recurrent_to_forget_weights_tensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
&recurrent_to_forget_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell);
@ -249,8 +258,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, recurrent_to_cell_weights_tensor);
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
&recurrent_to_cell_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@ -316,20 +327,25 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, forget_gate_bias_tensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, cell_gate_bias_tensor);
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, output_gate_bias_tensor);
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
@ -413,7 +429,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const bool time_major = params->time_major;
@ -421,15 +438,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
const int n_input = input->dims->data[2];
const TfLiteTensor* fw_input_to_output_weights =
GetInput(context, node, kFwInputToOutputWeightsTensor);
const TfLiteTensor* fw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
&fw_input_to_output_weights));
const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
const TfLiteTensor* bw_input_to_output_weights =
GetInput(context, node, kBwInputToOutputWeightsTensor);
const TfLiteTensor* bw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
&bw_input_to_output_weights));
const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
@ -437,8 +458,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
fw_input_to_output_weights->type);
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
const TfLiteTensor* fw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
&fw_recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
n_fw_cell);
@ -446,8 +469,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_input_to_output_weights->type);
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
const TfLiteTensor* bw_recurrent_to_output_weights =
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
const TfLiteTensor* bw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
&bw_recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
n_bw_cell);
@ -504,7 +529,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
TfLiteTensor* fw_activation_state =
GetVariableInput(context, node, kFwInputActivationStateTensor);
TF_LITE_ENSURE(context, fw_activation_state != nullptr);
@ -541,8 +568,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor.
node->temporaries->data[kFwScratchBuffer] =
op_data->scratch_tensor_index + kFwScratchBuffer;
TfLiteTensor* fw_scratch_buffer =
GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* fw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
&fw_scratch_buffer));
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@ -581,7 +609,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize the output tensors.
if (!params->merge_outputs) {
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = time_major ? max_time : n_batch;
bw_output_size->data[1] = time_major ? n_batch : max_time;
@ -600,8 +630,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor.
node->temporaries->data[kBwScratchBuffer] =
op_data->scratch_tensor_index + kBwScratchBuffer;
TfLiteTensor* bw_scratch_buffer =
GetTemporary(context, node, kBwScratchBuffer);
TfLiteTensor* bw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
&bw_scratch_buffer));
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@ -631,8 +662,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// (if present), activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = fw_input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -643,8 +675,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries->data[kFwActivationStateQuantized] =
op_data->scratch_tensor_index + kFwActivationStateQuantized;
TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized);
TfLiteTensor* fw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
&fw_activation_state_quantized));
fw_activation_state_quantized->type = fw_input_to_output_weights->type;
fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
@ -657,8 +691,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwActivationStateQuantized] =
op_data->scratch_tensor_index + kBwActivationStateQuantized;
TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized);
TfLiteTensor* bw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
&bw_activation_state_quantized));
bw_activation_state_quantized->type = fw_input_to_output_weights->type;
bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
@ -671,8 +707,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwCellStateQuantized] =
op_data->scratch_tensor_index + kFwCellStateQuantized;
TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized);
TfLiteTensor* fw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwCellStateQuantized,
&fw_cell_state_quantized));
fw_cell_state_quantized->type = fw_input_to_output_weights->type;
fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
@ -685,8 +723,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwCellStateQuantized] =
op_data->scratch_tensor_index + kBwCellStateQuantized;
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
TfLiteTensor* bw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwCellStateQuantized,
&bw_cell_state_quantized));
bw_cell_state_quantized->type = fw_input_to_output_weights->type;
bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
@ -705,7 +745,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -717,8 +760,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAuxInputScalingFactors] =
op_data->scratch_tensor_index + kAuxInputScalingFactors;
TfLiteTensor* aux_input_sf =
GetTemporary(context, node, kAuxInputScalingFactors);
TfLiteTensor* aux_input_sf;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputScalingFactors,
&aux_input_sf));
aux_input_sf->type = kTfLiteFloat32;
aux_input_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
@ -729,8 +774,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -741,8 +788,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -758,8 +807,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_fw_cell};
@ -775,8 +826,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] =
op_data->scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int n_cell = std::max(n_fw_cell, n_bw_cell);
@ -797,7 +850,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors for storing zero-points.
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -808,8 +863,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAuxInputZeroPoints] =
op_data->scratch_tensor_index + kAuxInputZeroPoints;
TfLiteTensor* aux_input_zp =
GetTemporary(context, node, kAuxInputZeroPoints);
TfLiteTensor* aux_input_zp;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
aux_input_zp->type = kTfLiteFloat32;
aux_input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
@ -820,8 +877,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -844,7 +903,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
@ -867,7 +928,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
@ -884,8 +947,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
TfLiteTensor* aux_input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputQuantized,
&aux_input_quantized));
aux_input_quantized->type = fw_input_to_output_weights->type;
aux_input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@ -906,26 +971,39 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
const TfLiteTensor* fw_input_to_forget_weights =
GetInput(context, node, kFwInputToForgetWeightsTensor);
const TfLiteTensor* fw_input_to_cell_weights =
GetInput(context, node, kFwInputToCellWeightsTensor);
const TfLiteTensor* fw_input_to_output_weights =
GetInput(context, node, kFwInputToOutputWeightsTensor);
const TfLiteTensor* fw_input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
&fw_input_to_forget_weights));
const TfLiteTensor* fw_input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToCellWeightsTensor,
&fw_input_to_cell_weights));
const TfLiteTensor* fw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
&fw_input_to_output_weights));
const TfLiteTensor* fw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
const TfLiteTensor* fw_recurrent_to_forget_weights =
GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
const TfLiteTensor* fw_recurrent_to_cell_weights =
GetInput(context, node, kFwRecurrentToCellWeightsTensor);
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
const TfLiteTensor* fw_recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
&fw_recurrent_to_forget_weights));
const TfLiteTensor* fw_recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
&fw_recurrent_to_cell_weights));
const TfLiteTensor* fw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
&fw_recurrent_to_output_weights));
const TfLiteTensor* fw_cell_to_input_weights =
GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
@ -936,12 +1014,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_input_gate_bias =
GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
const TfLiteTensor* fw_forget_gate_bias =
GetInput(context, node, kFwForgetGateBiasTensor);
const TfLiteTensor* fw_cell_gate_bias =
GetInput(context, node, kFwCellGateBiasTensor);
const TfLiteTensor* fw_output_gate_bias =
GetInput(context, node, kFwOutputGateBiasTensor);
const TfLiteTensor* fw_forget_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwForgetGateBiasTensor,
&fw_forget_gate_bias));
const TfLiteTensor* fw_cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
&fw_cell_gate_bias));
const TfLiteTensor* fw_output_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwOutputGateBiasTensor,
&fw_output_gate_bias));
const TfLiteTensor* fw_projection_weights =
GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
@ -950,30 +1033,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_activation_state =
GetVariableInput(context, node, kFwInputActivationStateTensor);
TF_LITE_ENSURE(context, fw_activation_state != nullptr);
TFLITE_DCHECK(fw_activation_state != nullptr);
TfLiteTensor* fw_cell_state =
GetVariableInput(context, node, kFwInputCellStateTensor);
TF_LITE_ENSURE(context, fw_cell_state != nullptr);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TFLITE_DCHECK(fw_cell_state != nullptr);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
// Tensors for the backward cell.
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
const TfLiteTensor* bw_input_to_forget_weights =
GetInput(context, node, kBwInputToForgetWeightsTensor);
const TfLiteTensor* bw_input_to_cell_weights =
GetInput(context, node, kBwInputToCellWeightsTensor);
const TfLiteTensor* bw_input_to_output_weights =
GetInput(context, node, kBwInputToOutputWeightsTensor);
const TfLiteTensor* bw_input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
&bw_input_to_forget_weights));
const TfLiteTensor* bw_input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToCellWeightsTensor,
&bw_input_to_cell_weights));
const TfLiteTensor* bw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
&bw_input_to_output_weights));
const TfLiteTensor* bw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
const TfLiteTensor* bw_recurrent_to_forget_weights =
GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
const TfLiteTensor* bw_recurrent_to_cell_weights =
GetInput(context, node, kBwRecurrentToCellWeightsTensor);
const TfLiteTensor* bw_recurrent_to_output_weights =
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
const TfLiteTensor* bw_recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
&bw_recurrent_to_forget_weights));
const TfLiteTensor* bw_recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
&bw_recurrent_to_cell_weights));
const TfLiteTensor* bw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
&bw_recurrent_to_output_weights));
const TfLiteTensor* bw_cell_to_input_weights =
GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
@ -984,12 +1081,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_input_gate_bias =
GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
const TfLiteTensor* bw_forget_gate_bias =
GetInput(context, node, kBwForgetGateBiasTensor);
const TfLiteTensor* bw_cell_gate_bias =
GetInput(context, node, kBwCellGateBiasTensor);
const TfLiteTensor* bw_output_gate_bias =
GetInput(context, node, kBwOutputGateBiasTensor);
const TfLiteTensor* bw_forget_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwForgetGateBiasTensor,
&bw_forget_gate_bias));
const TfLiteTensor* bw_cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
&bw_cell_gate_bias));
const TfLiteTensor* bw_output_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwOutputGateBiasTensor,
&bw_output_gate_bias));
const TfLiteTensor* bw_projection_weights =
GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
@ -999,19 +1101,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// State tensors.
TfLiteTensor* bw_activation_state =
GetVariableInput(context, node, kBwInputActivationStateTensor);
TF_LITE_ENSURE(context, bw_activation_state != nullptr);
TFLITE_DCHECK(bw_activation_state != nullptr);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
TF_LITE_ENSURE(context, bw_cell_state != nullptr);
TFLITE_DCHECK(bw_cell_state != nullptr);
TfLiteTensor* bw_output = params->merge_outputs
? nullptr
: GetOutput(context, node, kBwOutputTensor);
// Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
GetTemporary(context, node, kBwScratchBuffer);
TfLiteTensor* fw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
&fw_scratch_buffer));
TfLiteTensor* bw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
&bw_scratch_buffer));
// (Optional) auxiliary inputs.
const TfLiteTensor* aux_input =
@ -1112,27 +1216,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteUInt8:
case kTfLiteInt8: {
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized);
TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized);
TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized);
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
TfLiteTensor* fw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
&fw_activation_state_quantized));
TfLiteTensor* bw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
&bw_activation_state_quantized));
TfLiteTensor* fw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwCellStateQuantized,
&fw_cell_state_quantized));
TfLiteTensor* bw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwCellStateQuantized,
&bw_cell_state_quantized));
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
const int fw_row_sums_size = fw_row_sums->dims->data[0];
const int bw_row_sums_size = bw_row_sums->dims->data[0];
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(

View File

@ -97,21 +97,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size,
params->merge_outputs ? 1 : 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
GetInput(context, node, kFwWeightsTensor);
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
const TfLiteTensor* fw_hidden_state =
GetInput(context, node, kFwHiddenStateTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
const TfLiteTensor* bw_hidden_state =
GetInput(context, node, kBwHiddenStateTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
&fw_input_weights));
const TfLiteTensor* fw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentWeightsTensor,
&fw_recurrent_weights));
const TfLiteTensor* fw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
const TfLiteTensor* fw_hidden_state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
&fw_hidden_state));
const TfLiteTensor* bw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
&bw_input_weights));
const TfLiteTensor* bw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentWeightsTensor,
&bw_recurrent_weights));
const TfLiteTensor* bw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
const TfLiteTensor* bw_hidden_state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
&bw_hidden_state));
const TfLiteTensor* aux_input =
GetOptionalInputTensor(context, node, kAuxInputTensor);
@ -186,8 +199,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = fw_input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -198,8 +212,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries->data[kFwHiddenStateQuantized] =
op_data->scratch_tensor_index + kFwHiddenStateQuantized;
TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized);
TfLiteTensor* fw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwHiddenStateQuantized,
&fw_hidden_state_quantized));
fw_hidden_state_quantized->type = fw_input_weights->type;
fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
@ -213,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries->data[kBwHiddenStateQuantized] =
op_data->scratch_tensor_index + kBwHiddenStateQuantized;
TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* bw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwHiddenStateQuantized,
&bw_hidden_state_quantized));
bw_hidden_state_quantized->type = fw_input_weights->type;
bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
@ -229,8 +247,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store scaling factors of quantization.
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -242,7 +261,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
@ -257,8 +278,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -271,8 +294,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int num_row_sums = has_aux_input ? 3 : 2;
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums =
GetTemporary(context, node, /*index=*/kFwRowSums);
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
@ -285,8 +310,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node,
/*index=*/kBwRowSums);
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
@ -300,8 +327,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
TfLiteTensor* aux_input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputQuantized,
&aux_input_quantized));
aux_input_quantized->type = fw_input_weights->type;
aux_input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@ -315,7 +344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Resize outputs.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
@ -324,7 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
if (!params->merge_outputs) {
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
bw_output_size_array->data[0] = batch_size;
bw_output_size_array->data[1] = max_time;
@ -678,17 +711,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
GetInput(context, node, kFwWeightsTensor);
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
&fw_input_weights));
const TfLiteTensor* fw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentWeightsTensor,
&fw_recurrent_weights));
const TfLiteTensor* fw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
const TfLiteTensor* bw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
&bw_input_weights));
const TfLiteTensor* bw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentWeightsTensor,
&bw_recurrent_weights));
const TfLiteTensor* bw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
// Get auxiliary inputs.
const TfLiteTensor* aux_input =
@ -700,12 +744,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_hidden_state =
GetVariableInput(context, node, kFwHiddenStateTensor);
TF_LITE_ENSURE(context, fw_hidden_state != nullptr);
TFLITE_DCHECK(fw_hidden_state != nullptr);
TfLiteTensor* bw_hidden_state =
GetVariableInput(context, node, kBwHiddenStateTensor);
TF_LITE_ENSURE(context, bw_hidden_state != nullptr);
TFLITE_DCHECK(bw_hidden_state != nullptr);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
TfLiteTensor* bw_output = params->merge_outputs
? nullptr
: GetOutput(context, node, kBwOutputTensor);
@ -741,18 +787,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_hidden_state, bw_output);
case kTfLiteUInt8:
case kTfLiteInt8: {
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized);
TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
TfLiteTensor* fw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwHiddenStateQuantized,
&fw_hidden_state_quantized));
TfLiteTensor* bw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwHiddenStateQuantized,
&bw_hidden_state_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;

View File

@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// TODO(ahentz): these two checks would make the new implementation
// incompatible with some existing models, where params is not specified. It
@ -98,8 +101,11 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const int num_elements = NumElements(input);
TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
switch (input->type) {

View File

@ -29,8 +29,11 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -40,8 +43,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (input->type != kTfLiteFloat32) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Ceil");
}

View File

@ -41,9 +41,15 @@ TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Don't support string.
if (!is_string_allowed) {
@ -145,9 +151,15 @@ void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
}
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteBool:
@ -189,9 +201,15 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteBool:
@ -233,9 +251,15 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -268,9 +292,15 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -303,9 +333,15 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -338,9 +374,15 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:

View File

@ -45,7 +45,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The number of dimensions of the input tensors must match, and all
// dimensions except 'axis' must be equal.
const TfLiteTensor* t0 = GetInput(context, node, 0);
const TfLiteTensor* t0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
TfLiteType input_type = t0->type;
if (axis < 0) axis += t0->dims->size;
TF_LITE_ENSURE(context, axis >= 0);
@ -63,7 +64,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// will be the sum of inputs
int sum_axis = t0->dims->data[axis];
for (int i = 1; i < num_inputs; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
const TfLiteTensor* t;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
TF_LITE_ENSURE_EQ(context, t->type, input_type);
for (int d = 0; d < t0->dims->size; ++d) {
@ -80,7 +82,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
}
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
if (input_type == kTfLiteInt8) {
@ -88,7 +91,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// is a restriction we introduced to Int8 kernels.
VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
const TfLiteTensor* t;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
TF_LITE_ENSURE_EQ(context, t->params.zero_point,
output->params.zero_point);
@ -103,7 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
int axis = params->axis;
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (axis < 0) axis += output->dims->size;
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should

View File

@ -222,8 +222,10 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE(context, node->inputs->size >= 2);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
@ -316,9 +318,12 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
// Check dimensionality of input, filter
TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
@ -340,7 +345,7 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
TF_LITE_ENSURE(context, has_bias);
if (has_bias) {
bias = GetInput(context, node, 2);
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@ -493,8 +498,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
if (is_hybrid) {
node->temporaries->data[data->input_quantized_index] =
data->input_quantized_id;
TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
input_quantized->type = kTfLiteInt8;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -505,8 +512,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
node->temporaries->data[data->scaling_factors_index] =
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
// Only one scale factor per batch is typically necessary. See optimized
@ -522,8 +531,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
}
node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, data->accum_scratch_index);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->accum_scratch_index,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
const int scratch_width = batches * out_height * out_width;
@ -545,8 +556,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
context, affine_quantization->scale->size,
filter->dims->data[affine_quantization->quantized_dimension]);
node->temporaries->data[data->input_offset_index] = data->input_offset_id;
TfLiteTensor* input_offsets =
GetTemporary(context, node, data->input_offset_index);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_offset_index,
&input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
// See above comment for the need to allocate for height of inputs.
@ -560,8 +573,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
input_offsets_size));
}
node->temporaries->data[data->row_sums_index] = data->row_sums_id;
TfLiteTensor* row_sums =
GetTemporary(context, node, data->row_sums_index);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
// See above comment for the need to allocate for height of inputs.
@ -802,23 +817,34 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* output) {
TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
const int input_size = NumElements(input) / SizeOfDimension(input, 0);
const int batch_size = SizeOfDimension(input, 0);
int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
GetTemporary(context, node, data->input_quantized_index));
float* scaling_factors_ptr = GetTensorData<float>(
GetTemporary(context, node, data->scaling_factors_index));
int32_t* input_offset_ptr = GetTensorData<int32_t>(
GetTemporary(context, node, data->input_offset_index));
TfLiteTensor* quantized_input_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_quantized_index,
&quantized_input_tensor));
int8_t* quantized_input_ptr_batch =
GetTensorData<int8_t>(quantized_input_tensor);
TfLiteTensor* scaling_factors_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors_tensor));
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
TfLiteTensor* input_offset_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offset_tensor));
int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);
for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
@ -859,10 +885,14 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
case kGenericOptimized:
case kMultithreadOptimized:
case kCblasOptimized: {
TfLiteTensor* row_sums =
GetTemporary(context, node, data->row_sums_index);
TfLiteTensor* scratch =
GetTemporary(context, node, data->accum_scratch_index);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
TfLiteTensor* scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->accum_scratch_index, &scratch));
optimized_ops::HybridConvPerChannel(
op_params, scaling_factors_ptr, GetTensorShape(input),
quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
@ -877,14 +907,16 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
break;
}
}
return kTfLiteOk;
}
template <KernelType kernel_type>
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
@ -893,10 +925,17 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
const int batch_size = SizeOfDimension(input, 0);
const float* input_ptr = GetTensorData<float>(input);
int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
GetTemporary(context, node, data->input_quantized_index));
float* scaling_factors_ptr = GetTensorData<float>(
GetTemporary(context, node, data->scaling_factors_index));
TfLiteTensor* quantized_input_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_quantized_index,
&quantized_input_tensor));
int8_t* quantized_input_ptr_batch =
GetTensorData<int8_t>(quantized_input_tensor);
TfLiteTensor* scaling_factors_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors_tensor));
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
// Per-batch input quantization for higher accuracy.
{
@ -939,6 +978,8 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
break;
}
}
return kTfLiteOk;
}
template <KernelType kernel_type, TfLiteType input_type>
@ -946,9 +987,12 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
bool has_bias = node->inputs->size == 3;
const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
TfLiteTensor* im2col =
@ -970,14 +1014,17 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
if (data->is_hybrid_per_channel) {
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
filter, bias, im2col, output);
TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
context, node, params, data, input,
filter, bias, im2col, output));
} else {
TfLiteTensor* accum_scratch =
&context->tensors[node->temporaries
->data[data->accum_scratch_index]];
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
bias, im2col, accum_scratch, output);
TF_LITE_ENSURE_OK(context,
EvalHybrid<kernel_type>(context, node, params, data,
input, filter, bias, im2col,
accum_scratch, output));
}
} else {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
@ -1006,7 +1053,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) {
case kTfLiteFloat32:

View File

@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@ -84,8 +87,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
#define TF_LITE_DEPTH_TO_SPACE(type, scalar) \
tflite::DepthToSpaceParams op_params; \

View File

@ -104,12 +104,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bool hasBias = NumInputs(node) == 3;
TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFilterTensor, &filter));
const TfLiteTensor* bias = nullptr;
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4);
@ -132,7 +137,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 0), 1);
if (hasBias) {
bias = GetInput(context, node, kBiasTensor);
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) {
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@ -224,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries->data[data->input_quantized_index] =
data->input_quantized_id;
TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
input_quantized->type = kTfLiteInt8;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -235,8 +242,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[data->scaling_factors_index] =
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
const int batch_size = SizeOfDimension(input, 0);
@ -248,8 +257,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[data->input_offset_index] = data->input_offset_id;
TfLiteTensor* input_offsets =
GetTemporary(context, node, data->input_offset_index);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@ -446,13 +457,21 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
&output_activation_max);
const int input_size = NumElements(input) / SizeOfDimension(input, 0);
const int batch_size = SizeOfDimension(input, 0);
const TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
int8_t* quantized_input_ptr_batch = input_quantized->data.int8;
float* scaling_factors_ptr = GetTensorData<float>(
GetTemporary(context, node, data->scaling_factors_index));
int32_t* input_offset_ptr = GetTensorData<int32_t>(
GetTemporary(context, node, data->input_offset_index));
TfLiteTensor* scaling_factors_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors_tensor));
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
TfLiteTensor* input_offset_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offset_tensor));
int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);
for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
@ -504,9 +523,14 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFilterTensor, &filter));
const TfLiteTensor* bias =
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
TFLITE_DCHECK_EQ(input_type, input->type);
@ -547,7 +571,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:

View File

@ -146,12 +146,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast<OpData*>(node->user_data);
// Inputs: box_encodings, scores, anchors
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_anchors =
GetInput(context, node, kInputTensorAnchors);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* input_anchors;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
&input_anchors));
TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
@ -163,27 +168,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// num_detections
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
// Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
detection_boxes->type = kTfLiteFloat32;
SetTensorSizes(context, detection_boxes,
{kBatchSize, num_detected_boxes, kNumCoordBox});
// Output Tensor detection_classes: size is set to (1, num_detected_boxes)
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
detection_classes->type = kTfLiteFloat32;
SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});
// Output Tensor detection_scores: size is set to (1, num_detected_boxes)
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
detection_scores->type = kTfLiteFloat32;
SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});
// Output Tensor num_detections: size is set to 1
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));
num_detections->type = kTfLiteFloat32;
// TODO (chowdhery): Make it a scalar when available
SetTensorSizes(context, num_detections, {1});
@ -269,13 +282,16 @@ T ReInterpretTensor(TfLiteTensor* tensor) {
TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
OpData* op_data) {
// Parse input tensor boxencodings
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
const int num_boxes = input_box_encodings->dims->data[1];
TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
const TfLiteTensor* input_anchors =
GetInput(context, node, kInputTensorAnchors);
const TfLiteTensor* input_anchors;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
&input_anchors));
// Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
CenterSizeEncoding box_centersize;
@ -389,8 +405,10 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
TfLiteContext* context, TfLiteNode* node, OpData* op_data,
const std::vector<float>& scores, std::vector<int>* selected,
int max_detections) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];
const int num_boxes = input_box_encodings->dims->data[1];
@ -468,21 +486,33 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
TfLiteNode* node,
OpData* op_data,
const float* scores) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];
TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));
const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
@ -595,21 +625,33 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
TfLiteNode* node,
OpData* op_data,
const float* scores) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];
TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));
const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
@ -680,10 +722,14 @@ void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
TfLiteNode* node, OpData* op_data) {
// Get the input tensors
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],

View File

@ -74,9 +74,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -200,9 +206,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);

View File

@ -66,8 +66,10 @@ template <IsSupportedType is_supported_type, const char* op_name>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!is_supported_type(input->type)) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
@ -114,8 +116,10 @@ template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
std::function<T(T)> func,
TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);

View File

@ -46,14 +46,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
const TfLiteTensor* value = GetInput(context, node, 1);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
outputSize->data[0] = SizeOfDimension(lookup, 0);
@ -129,9 +132,12 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* value = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (value->type) {
case kTfLiteFloat32:
return EvalSimple(context, node, lookup, value, output);

View File

@ -83,19 +83,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* ids = GetInput(context, node, 0);
const TfLiteTensor* ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
const TfLiteTensor* indices = GetInput(context, node, 1);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
const TfLiteTensor* shape = GetInput(context, node, 2);
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
const TfLiteTensor* weights = GetInput(context, node, 3);
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
@ -104,11 +108,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
SizeOfDimension(weights, 0));
const TfLiteTensor* value = GetInput(context, node, 4);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
// Mark the output as a dynamic tensor.
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
output->allocation_type = kTfLiteDynamic;
@ -140,12 +146,18 @@ void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* ids = GetInput(context, node, 0);
const TfLiteTensor* indices = GetInput(context, node, 1);
const TfLiteTensor* dense_shape = GetInput(context, node, 2);
const TfLiteTensor* weights = GetInput(context, node, 3);
const TfLiteTensor* value = GetInput(context, node, 4);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
const TfLiteTensor* dense_shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
const int lookup_rank = SizeOfDimension(indices, 1);
const int embedding_rank = NumDimensions(value);

View File

@ -73,9 +73,12 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
output->type = input->type;
if (IsConstantTensor(axis)) {
int axis_value;
@ -89,9 +92,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Just copy input to output.
const TfLiteTensor* input = GetInput(context, node, kInput);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
if (IsDynamicTensor(output)) {
int axis_value;
TF_LITE_ENSURE_OK(context,

View File

@ -72,8 +72,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
const TfLiteTensor* value = GetInput(context, node, kValueTensor);
const TfLiteTensor* dims;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
// Make sure the 1st input tensor is 1-D.
TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);
@ -85,7 +87,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Make sure the 2nd input tensor is a scalar.
TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = value->type;
if (IsConstantTensor(dims)) {
@ -111,12 +115,16 @@ TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* value = GetInput(context, node, kValueTensor);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
const TfLiteTensor* dims;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
}
#define TF_LITE_FILL(data_type) \

View File

@ -35,8 +35,11 @@ enum KernelType {
};
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -47,8 +50,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (type == kGenericOptimized) {
optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),

View File

@ -64,9 +64,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@ -126,9 +132,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input1->type) {
case kTfLiteInt32: {

View File

@ -58,9 +58,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@ -120,9 +126,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input1->type) {
case kTfLiteInt32: {

View File

@ -155,13 +155,18 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
: 2;
TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* bias =
(node->inputs->size == 3)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Check proper datatype match among all Input Tensors
TF_LITE_ENSURE_STATUS(
@ -214,7 +219,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(5);
node->temporaries->data[0] = data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
input_quantized->type = filter->type;
input_quantized->allocation_type = kTfLiteArenaRw;
@ -223,7 +230,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
node->temporaries->data[1] = data->scratch_tensor_index + 1;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
@ -236,7 +245,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[2] = data->scratch_tensor_index + 2;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -250,7 +261,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[3] = data->scratch_tensor_index + 3;
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@ -260,7 +273,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
input_offsets_size));
}
node->temporaries->data[4] = data->scratch_tensor_index + 4;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/4, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_units};
@ -300,8 +315,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check for supported activation types.
auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const bool is_quantized =
((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
@ -484,11 +502,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t output_offset = output->params.zero_point;
// Only the Pie path supports quantized models and float inputs/outputs.
if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/4, &row_sums));
return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, accum_scratch, row_sums,
input_offsets, output);
@ -693,13 +721,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* bias =
(node->inputs->size == 3)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (filter->type) {
case kTfLiteFloat32:
@ -708,8 +741,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
if (params->weights_format ==
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
TfLiteTensor* shuffled_input_workspace =
GetOutput(context, node, kShuffledInputWorkspaceTensor);
TfLiteTensor* shuffled_input_workspace;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
&shuffled_input_workspace));
return EvalShuffledQuantized<kernel_type>(context, node, params, data,
input, filter, bias, output,
shuffled_input_workspace);

View File

@ -38,9 +38,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* positions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPositions, &positions));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (positions->type) {
case kTfLiteInt64:
@ -132,9 +137,14 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* positions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPositions, &positions));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (positions->type == kTfLiteInt32) {
switch (input->type) {

View File

@ -33,9 +33,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* params = GetInput(context, node, kParams);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* params;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (params->type) {
case kTfLiteFloat32:
@ -140,9 +144,13 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* params = GetInput(context, node, kParams);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* params;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (indices->type) {
case kTfLiteInt32:

View File

@ -37,6 +37,7 @@ limitations under the License.
#include <cstring>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
@ -54,15 +55,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
const TfLiteTensor* key = GetInput(context, node, 1);
const TfLiteTensor* key;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
const TfLiteTensor* value = GetInput(context, node, 2);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
SizeOfDimension(value, 0));
@ -70,12 +74,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
}
TfLiteTensor* hits = GetOutput(context, node, 1);
TfLiteTensor* hits;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
hitSize->data[0] = SizeOfDimension(lookup, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_EQ(context, value->type, output->type);
TfLiteStatus status = kTfLiteOk;
@ -94,11 +100,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* hits = GetOutput(context, node, 1);
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* key = GetInput(context, node, 1);
const TfLiteTensor* value = GetInput(context, node, 2);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteTensor* hits;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
const TfLiteTensor* key;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
const int num_rows = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / num_rows;

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
@ -52,7 +53,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size > 0);
// The first input is the condition.
const TfLiteTensor* cond = GetInput(context, node, 0);
const TfLiteTensor* cond;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
// Currently only bool is supported.
// TODO(ycling): Support other types since TensorFlow also support
// non-bool types as condition.
@ -83,7 +85,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < num_inputs; ++i) {
// The first input of the node is the condition. The indices of the inputs
// passed to the subgraphs are offset by 1.
const TfLiteTensor* input = GetInput(context, node, i + 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
std::vector<int> dims(input->dims->data,
input->dims->data + input->dims->size);
subgraph->ResizeInputTensor(i, dims);
@ -113,7 +116,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
for (int i = 0; i < num_outputs; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (has_dynamic_output_tensors) {
SetTensorToDynamic(output);
} else {
@ -133,7 +137,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* cond = GetInput(context, node, 0);
const TfLiteTensor* cond;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
bool cond_value = cond->data.b[0];
Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
@ -147,7 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph& active_branch_subgraph =
*(*subgraphs)[active_branch_subgraph_index];
for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
const TfLiteTensor* input = GetInput(context, node, i + 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
TfLiteTensor* subgraph_input =
active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
@ -164,7 +170,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bool has_dynamic_output_tensors = false;
for (int i = 0; i < node->outputs->size; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (IsDynamicTensor(output)) {
has_dynamic_output_tensors = true;
break;
@ -173,7 +180,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (has_dynamic_output_tensors) {
for (int i = 0; i < node->outputs->size; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TfLiteTensor* subgraph_output =
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
@ -185,7 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
const TfLiteTensor* subgraph_output =
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
}

View File

@ -44,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
@ -74,8 +77,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// TODO(b/143912164): instead of hardcode the epsilon here, we should read it
// from tensorflow, i.e., adding a params.

View File

@ -39,8 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@ -61,8 +64,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32) {
#define TF_LITE_LOCAL_RESPONSE_NORM(type) \

View File

@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@ -84,9 +90,15 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool)) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(

View File

@ -73,22 +73,26 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* hash = GetInput(context, node, 0);
const TfLiteTensor* hash;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
// Support up to 32 bits.
TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
const TfLiteTensor* input = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
if (NumInputs(node) == 3) {
const TfLiteTensor* weight = GetInput(context, node, 2);
const TfLiteTensor* weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight));
TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
SizeOfDimension(input, 0));
}
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
switch (params->type) {
case kTfLiteLshProjectionSparse:
@ -170,9 +174,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
const TfLiteTensor* hash = GetInput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 1);
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
int32_t* out_buf = out_tensor->data.i32;
const TfLiteTensor* hash;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
const TfLiteTensor* weight =
NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);

View File

@ -149,7 +149,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
const TfLiteTensor* cell_state =
GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
auto* cell_state_params =
static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
@ -173,25 +175,38 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
OpData* op_data = static_cast<OpData*>(node->user_data);
const bool use_layer_norm = op_data->use_layer_norm;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -227,7 +242,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
std::vector<int32> intermediate_zp;
for (int i = 0; i < 4; ++i) {
if (use_layer_norm) {
const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
TfLiteTensor* intermediate;
TF_LITE_ENSURE_OK(context,
GetIntermediatesSafe(context, node, i, &intermediate));
auto* params = static_cast<TfLiteAffineQuantization*>(
intermediate->quantization.params);
intermediate_scale.push_back(params->scale->data[0]);
@ -240,7 +257,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
}
// In the absense of projection, hidden becomes otuput and this intermediate
// is ignored.
const TfLiteTensor* hidden = GetIntermediates(context, node, 4);
TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
auto* hidden_params =
static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
intermediate_scale.push_back(hidden_params->scale->data[0]);
@ -446,24 +464,37 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
// Get all tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -483,12 +514,15 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -774,7 +808,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
cell_state->quantization.params);
@ -825,8 +861,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE(context, params->cell_clip >= 0);
TF_LITE_ENSURE(context, params->proj_clip >= 0);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
@ -845,8 +883,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
input_to_forget_weights->type);
}
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
@ -865,8 +905,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
input_to_forget_weights->type);
}
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell);
@ -875,8 +917,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@ -948,8 +992,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
}
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -958,8 +1003,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -968,8 +1014,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -1105,7 +1152,8 @@ TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
OpData* op_data,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
@ -1115,21 +1163,33 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -1254,20 +1314,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const bool is_integer = input->type == kTfLiteInt8;
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
n_cell);
@ -1279,7 +1344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_cell, use_layer_norm, is_integer));
// Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
@ -1339,7 +1406,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (!is_integer) {
node->temporaries->data[kScratchBuffer] =
op_data->scratch_tensor_index + kScratchBuffer;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
@ -1367,8 +1436,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -1378,8 +1448,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
op_data->scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
TfLiteTensor* output_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateQuantized,
&output_state_quantized));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@ -1392,8 +1464,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kCellStateQuantized] =
op_data->scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, kCellStateQuantized);
TfLiteTensor* cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kCellStateQuantized,
&cell_state_quantized));
cell_state_quantized->type = input_to_output_weights->type;
cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@ -1410,7 +1484,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -1422,8 +1499,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -1434,8 +1513,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -1451,8 +1532,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell};
@ -1468,7 +1551,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// multiplication before multiplication by scaling factor
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
@ -1482,7 +1567,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -1493,8 +1580,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -1516,7 +1605,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
}
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
const int row_sums_dims[2] = {row_sums_rows, n_cell};
@ -1664,8 +1755,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, scratch_index);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8;
@ -1701,8 +1794,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, scratch_index);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
if (scratch_index == 0 || scratch_index == 1) {
scratch_tensor->type = kTfLiteInt8;
} else {
@ -1731,25 +1826,38 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -1769,12 +1877,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -1783,16 +1894,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TFLITE_DCHECK(output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TFLITE_DCHECK(cell_state != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, 0);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch_buffer));
return lstm_eval::EvalFloat(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -1818,7 +1933,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool is_hybrid = (input->type == kTfLiteFloat32);
const bool is_sparse = input_to_output_weights->sparsity != nullptr;
if (is_hybrid) {
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
const int row_sums_size = row_sums->dims->data[0];
if (is_sparse) {
TfLiteTensor* input_to_input_weights_ledger =
@ -1957,12 +2074,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else {
const int num_intermediate_tensors = node->intermediates->size;
if (num_intermediate_tensors == 5) {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0);
TfLiteTensor* scratch1 = GetTemporary(context, node, 1);
TfLiteTensor* scratch2 = GetTemporary(context, node, 2);
TfLiteTensor* scratch3 = GetTemporary(context, node, 3);
TfLiteTensor* scratch4 = GetTemporary(context, node, 4);
TfLiteTensor* scratch5 = GetTemporary(context, node, 5);
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch1;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -1978,14 +2107,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0);
TfLiteTensor* scratch1 = GetTemporary(context, node, 1);
TfLiteTensor* scratch2 = GetTemporary(context, node, 2);
TfLiteTensor* scratch3 = GetTemporary(context, node, 3);
TfLiteTensor* scratch4 = GetTemporary(context, node, 4);
TfLiteTensor* scratch5 = GetTemporary(context, node, 5);
TfLiteTensor* scratch6 = GetTemporary(context, node, 6);
TfLiteTensor* scratch7 = GetTemporary(context, node, 7);
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch1;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
TfLiteTensor* scratch6;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 6, &scratch6));
TfLiteTensor* scratch7;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 7, &scratch7));
return lstm_eval::EvalInteger8x8_8(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -2046,12 +2191,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
const TfLiteTensor* input = GetInput(context, node, kInputData);
const TfLiteTensor* prev_activation =
GetInput(context, node, kInputPrevActivation);
const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
const TfLiteTensor* prev_activation;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
&prev_activation));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));
TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
const int num_batches = input->dims->data[0];
@ -2073,11 +2225,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);
TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
TfLiteTensor* activation_temp =
GetOutput(context, node, kOutputActivationTemp);
TfLiteTensor* activation_out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
&activation_out));
TfLiteTensor* state_out;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));
TF_LITE_ENSURE_OK(context, context->ResizeTensor(
context, activation_out,
@ -2106,18 +2265,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputData);
const TfLiteTensor* prev_activation =
GetInput(context, node, kInputPrevActivation);
const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
const TfLiteTensor* prev_activation;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
&prev_activation));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));
TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
TfLiteTensor* activation_temp =
GetOutput(context, node, kOutputActivationTemp);
TfLiteTensor* activation_out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
&activation_out));
TfLiteTensor* state_out;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));
if (input->type == kTfLiteFloat32 &&
prev_activation->type == kTfLiteFloat32 &&

View File

@ -32,12 +32,15 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Resize the output tensor.
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
for (int i = 0; i < input_dims_size; i++) {
@ -116,8 +119,11 @@ void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
FillDiagHelper(input, output);
return kTfLiteOk;
}

View File

@ -33,12 +33,15 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
for (int i = 0; i < input_dims_size; i++) {
@ -126,9 +129,14 @@ void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* diag;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDiagonalTensor, &diag));
FillDiagHelper(input, diag, output);
return kTfLiteOk;
}

View File

@ -73,9 +73,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_wav;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorWav, &input_wav));
const TfLiteTensor* input_rate;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorRate, &input_rate));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
@ -101,9 +107,15 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_wav;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorWav, &input_wav));
const TfLiteTensor* input_rate;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorRate, &input_rate));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const int32 sample_rate = *GetTensorData<int>(input_rate);

View File

@ -162,8 +162,10 @@ struct MirrorPadWorkerTask : cpu_backend_threadpool::Task {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ruy::profiler::ScopeLabel label("MirrorPad");
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
const TfLiteTensor* padding_matrix;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
auto* params =
reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);
@ -172,7 +174,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
const int input_dims = NumDimensions(input_tensor);
TfLiteTensor* output_tensor = GetOutput(context, node, 0);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
if (IsDynamicTensor(output_tensor)) {
auto output_size = GetPaddedOutputShape(input_tensor, padding_matrix);
if (output_size == nullptr) {
@ -258,9 +261,12 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
TfLiteTensor* output_tensor = GetOutput(context, node, 0);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
const TfLiteTensor* padding_matrix;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
TF_LITE_ENSURE_EQ(context, NumDimensions(padding_matrix), 2);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(padding_matrix, 0),

View File

@ -75,9 +75,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@ -259,9 +265,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalMul<kernel_type>(context, node, params, data, input1, input2, output);

View File

@ -34,8 +34,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input->type;
return context->ResizeTensor(context, output,
@ -43,8 +46,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input->type) {
case kTfLiteInt64:
reference_ops::Negate(

View File

@ -79,20 +79,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Boxes & Scores.
const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
const TfLiteTensor* input_boxes;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4);
const int num_boxes = SizeOfDimension(input_boxes, 0);
const TfLiteTensor* input_scores =
GetInput(context, node, kInputTensorScores);
const TfLiteTensor* input_scores;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1);
TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0));
// Max output size.
const TfLiteTensor* input_max_output_size =
GetInput(context, node, kInputTensorMaxOutputSize);
const TfLiteTensor* input_max_output_size;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorMaxOutputSize,
&input_max_output_size));
TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
@ -103,30 +108,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// IoU & Score thresholds.
const TfLiteTensor* input_iou_threshold =
GetInput(context, node, kInputTensorIouThreshold);
const TfLiteTensor* input_iou_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorIouThreshold,
&input_iou_threshold));
TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0);
const TfLiteTensor* input_score_threshold =
GetInput(context, node, kInputTensorScoreThreshold);
const TfLiteTensor* input_score_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorScoreThreshold,
&input_score_threshold));
TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0);
if (is_soft_nms) {
const TfLiteTensor* input_sigma =
GetInput(context, node, kInputTensorSigma);
const TfLiteTensor* input_sigma;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3);
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
TfLiteTensor* output_selected_indices;
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
&output_selected_indices));
output_selected_indices->type = kTfLiteInt32;
TfLiteTensor* output_selected_scores =
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
TfLiteTensor* output_selected_scores;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kSoftNMSOutputTensorSelectedScores,
&output_selected_scores));
output_selected_scores->type = kTfLiteFloat32;
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
TfLiteTensor* output_num_selected_indices;
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});
@ -139,11 +157,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
} else {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
TfLiteTensor* output_selected_indices;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
&output_selected_indices));
output_selected_indices->type = kTfLiteInt32;
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
TfLiteTensor* output_num_selected_indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});
@ -179,20 +201,29 @@ void ResetUnusedElementsToZeroes(const int max_output_size,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool is_soft_nms = NumInputs(node) == 6;
const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
const TfLiteTensor* input_boxes;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
const int num_boxes = SizeOfDimension(input_boxes, 0);
const TfLiteTensor* input_scores =
GetInput(context, node, kInputTensorScores);
const TfLiteTensor* input_max_output_size =
GetInput(context, node, kInputTensorMaxOutputSize);
const TfLiteTensor* input_scores;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
const TfLiteTensor* input_max_output_size;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorMaxOutputSize,
&input_max_output_size));
const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
const TfLiteTensor* input_iou_threshold =
GetInput(context, node, kInputTensorIouThreshold);
const TfLiteTensor* input_iou_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorIouThreshold,
&input_iou_threshold));
const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
const TfLiteTensor* input_score_threshold =
GetInput(context, node, kInputTensorScoreThreshold);
const TfLiteTensor* input_score_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorScoreThreshold,
&input_score_threshold));
const float score_threshold = *GetTensorData<float>(input_score_threshold);
TfLiteTensor* output_selected_indices = nullptr;
@ -200,8 +231,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_num_selected_indices = nullptr;
if (is_soft_nms) {
const TfLiteTensor* input_sigma =
GetInput(context, node, kInputTensorSigma);
const TfLiteTensor* input_sigma;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
const float soft_nms_sigma = *GetTensorData<float>(input_sigma);
if (soft_nms_sigma < 0) {
context->ReportError(context, "Invalid sigma value for soft NMS: %f",
@ -209,12 +241,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
output_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
output_selected_scores =
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
&output_selected_indices));
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kSoftNMSOutputTensorSelectedScores,
&output_selected_scores));
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
@ -228,10 +265,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
max_output_size_value, *output_num_selected_indices->data.i32,
output_selected_indices->data.i32, output_selected_scores->data.f);
} else {
output_selected_indices =
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
&output_selected_indices));
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
}

View File

@ -109,7 +109,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(1);
node->temporaries->data[0] = op_data->cache_tensor_id;
TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* dequantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &dequantized));
dequantized->type = op_context.ref->type;
dequantized->allocation_type = kTfLiteDynamic;
@ -142,7 +144,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
// Dequantize the input
TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* dequantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &dequantized));
auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
context, node, op_context.input, dequantized);
if (status != kTfLiteOk) {

View File

@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input0 = GetInput(context, node, 0);
const TfLiteTensor* input0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
const int dimension_size = NumDimensions(input0) + 1;
if (data->axis < 0) {
data->axis += dimension_size;
@ -55,7 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Make sure all inputs have the same shape and type.
for (int i = 1; i < data->values_count; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
}
@ -72,13 +74,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);
// Guarantee input/output quantization params match as we do not support
// packing quantized tensors.
for (int i = 0; i < data->values_count; i++) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output->params.zero_point);
TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
@ -106,7 +111,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLitePackParams* data =
reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (output->type) {
case kTfLiteFloat32: {
return PackImpl<float>(context, node, output, data->values_count,

View File

@ -71,8 +71,10 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@ -368,8 +370,10 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
@ -399,8 +403,10 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
@ -430,8 +436,10 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
L2EvalFloat<kernel_type>(context, node, params, data, input, output);

View File

@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@ -112,9 +118,15 @@ TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (output->type) {
case kTfLiteInt32: {

View File

@ -97,8 +97,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only support affine per-layer quantization.
@ -141,8 +143,10 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const RuntimeShape input_shape = GetTensorShape(input);
const RuntimeShape output_shape = GetTensorShape(output);

View File

@ -83,9 +83,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* start = GetInput(context, node, kStartTensor);
const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
const TfLiteTensor* start;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
const TfLiteTensor* limit;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
const TfLiteTensor* delta;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
// Make sure all the inputs are scalars.
TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
@ -103,7 +106,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = dtype;
if (IsConstantTensor(start) && IsConstantTensor(limit) &&
@ -130,11 +135,16 @@ void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* start = GetInput(context, node, kStartTensor);
const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
const TfLiteTensor* start;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
const TfLiteTensor* limit;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
const TfLiteTensor* delta;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -31,8 +31,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = kTfLiteInt32;
// By design, the input shape is always known at the time of Prepare, even

View File

@ -34,12 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputValue, &output));
SetTensorToDynamic(output);
return kTfLiteOk;
@ -48,15 +51,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();
auto* variable = resource::GetResourceVariable(&resources, resource_id);
TF_LITE_ENSURE(context, variable != nullptr);
TfLiteTensor* variable_tensor = variable->GetTensor();
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputValue, &output));
TF_LITE_ENSURE_TYPES_EQ(context, variable_tensor->type, output->type);
TF_LITE_ENSURE_OK(

View File

@ -170,7 +170,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
scratch_tensor->type = kTfLiteInt32;
scratch_tensor->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
@ -180,11 +182,15 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
// Creates a temp tensor to store resolved axis given input data.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
resolved_axis->type = kTfLiteInt32;
// Creates a temp tensor to store temp sums when calculating mean.
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
switch (op_context->input->type) {
case kTfLiteFloat32:
temp_sum->type = kTfLiteFloat32;
@ -217,7 +223,9 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
// Leaves work to Eval if axis is not constant; else resizes output.
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(op_context.output);
@ -233,7 +241,8 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
return PrepareSimple(context, node);
}
@ -254,7 +263,9 @@ TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
data->shift = exponent;
}
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(temp_sum);
return kTfLiteOk;
@ -343,9 +354,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context,
@ -490,8 +507,12 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
OpContext* op_context, T init_value,
T reducer(const T current, const T in)) {
int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
@ -621,9 +642,15 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
if (need_rescale) {
// Rescaling 8bit reduce sum.
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -38,8 +38,11 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
scoped_output_shape(output_shape, TfLiteIntArrayFree);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Tensorflow's Reshape allows one of the shape components to have the
// special -1 value, meaning it will be calculated automatically based on the
@ -70,6 +73,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
TfLiteNode* node) {
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
if (shape == nullptr) return nullptr;
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
for (int i = 0; i < output_shape->size; ++i) {
@ -103,7 +107,8 @@ inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
// Check if the shape tensor is valid. Shapes should be int32 vectors.
inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
return (shape != nullptr && shape->dims->size == 1 &&
shape->type == kTfLiteInt32);
}
TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) {
@ -122,7 +127,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// calculate their shapes now. String tensors don't benefit from having their
// shapes precalculated because the actual memory can only be allocated after
// we know all the content.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type != kTfLiteString) {
if (NumInputs(node) == 1 ||
IsConstantTensor(GetInput(context, node, kShapeTensor))) {
@ -135,8 +142,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// There are two ways in which the 'output' can be made dynamic: it could be
// a string tensor, or its shape cannot be calculated during Prepare(). In

View File

@ -61,9 +61,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// TODO(ahentz): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@ -96,9 +100,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -60,9 +60,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// TODO(ahentz): Our current implementations rely on the input being 4D,
// and the size being 1D tensor with exactly 2 elements.
@ -85,9 +89,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeNearestNeighborParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -35,8 +35,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
@ -59,7 +61,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(context, "Current does not support more than 1 axis.");
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
@ -67,8 +71,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kAxisTensor, &axis_tensor));
int axis = GetTensorData<int32_t>(axis_tensor)[0];
const int rank = NumDimensions(input);
if (axis < 0) {
@ -76,7 +83,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (output->type) {
case kTfLiteFloat32: {

View File

@ -36,8 +36,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* seq_lengths = GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* seq_lengths;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
@ -56,7 +59,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
@ -65,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <typename T, typename TS>
TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* seq_lengths_tensor =
GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* seq_lengths_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
&seq_lengths_tensor));
const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
auto* params =
@ -86,7 +93,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
reference_ops::ReverseSequence<T, TS>(
seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
@ -98,8 +107,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
template <typename T>
TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* seq_lengths_tensor =
GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* seq_lengths_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
&seq_lengths_tensor));
switch (seq_lengths_tensor->type) {
case kTfLiteInt32: {
return ReverseSequenceImpl<T, int32_t>(context, node);
@ -119,7 +129,9 @@ TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (output->type) {
case kTfLiteFloat32: {

View File

@ -73,16 +73,20 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
data->fft_double_working_area_id = first_new_index + 1;
// Set up FFT integer working area buffer.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
fft_integer_working_area->type = kTfLiteInt32;
// If fft_length is not a constant tensor, fft_integer_working_area will be
// set to dynamic later in Prepare.
fft_integer_working_area->allocation_type = kTfLiteArenaRw;
// Set up FFT double working area buffer.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
// fft_double_working_area is a double tensor. Ideally, double should be
// added into tflite data types. However, since fft_double_working_area is a
// temporary tensor, and there are no ops having double input/output tensors
@ -100,10 +104,13 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int num_dims = NumDimensions(input);
TF_LITE_ENSURE(context, num_dims >= 2);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
// The lib, fft2d, can only handle fft_lengths of power of 2.
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
@ -116,15 +123,19 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
int half_fft_working_length = fft_working_length / 2;
// Resize output tensor.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
output_shape->data[num_dims - 2] = fft_length_data[0];
output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
// Resize temporary tensors, fft_integer_working_area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
fft_integer_working_area_shape->data[0] =
2 + static_cast<int>(sqrt(fft_working_length));
@ -132,8 +143,10 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
fft_integer_working_area_shape));
// Resize temporary tensors, fft_double_working_area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
fft_double_working_area_shape->data[0] =
half_fft_working_length + fft_width / 4;
@ -157,7 +170,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// Check type and shape of the input tensor
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
if (input->type != kTfLiteFloat32) {
context->ReportError(context,
@ -167,7 +181,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Check type and shape of the fft_length tensor
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const RuntimeShape fft_length_shape = GetTensorShape(fft_length);
TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
@ -183,17 +199,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));
// Set output type
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = kTfLiteComplex64;
// Exit early if fft_length is a non-const tensor. Set output tensor and
// temporary tensors to dynamic, so that their tensor sizes can be determined
// in Eval.
if (!IsConstantTensor(fft_length)) {
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
SetTensorToDynamic(fft_integer_working_area);
SetTensorToDynamic(fft_double_working_area);
SetTensorToDynamic(output);
@ -325,11 +347,16 @@ void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
}
TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const float* input_data = GetTensorData<float>(input);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
complex<float>* output_data = GetTensorData<complex<float>>(output);
int fft_height, fft_width;
@ -358,14 +385,18 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
}
// Get buffer for integer working area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
int* fft_integer_working_area_data =
GetTensorData<int>(fft_integer_working_area);
// Get buffer for double working area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
// Get double value out of the memory of fft_double_working_area_data.
double* fft_double_working_area_data = reinterpret_cast<double*>(
GetTensorData<int64_t>(fft_double_working_area));
@ -393,10 +424,15 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type != kTfLiteComplex64) {
context->ReportError(context,

View File

@ -30,8 +30,11 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));

View File

@ -74,9 +74,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
const TfLiteTensor* updates;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
switch (updates->type) {
case kTfLiteFloat32:
@ -96,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = updates->type;
if (IsConstantTensor(shape)) {
@ -163,10 +168,15 @@ TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
const TfLiteTensor* updates;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (indices->type) {
case kTfLiteInt32:

View File

@ -64,11 +64,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputDataTensor, &data));
const TfLiteTensor* segment_ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
&segment_ids));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE(context,
data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
@ -82,10 +86,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputDataTensor, &data));
const TfLiteTensor* segment_ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
&segment_ids));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -61,11 +61,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_condition;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
&input_condition));
const TfLiteTensor* input_x;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorX, &input_x));
const TfLiteTensor* input_y;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorY, &input_y));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Input must be bool.
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
@ -111,11 +118,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_condition;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
&input_condition));
const TfLiteTensor* input_x;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorX, &input_x));
const TfLiteTensor* input_y;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorY, &input_y));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
#define TF_LITE_SELECT(type, op) \
reference_ops::op(GetTensorShape(input_condition), \

View File

@ -40,8 +40,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
switch (params->out_type) {

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
@ -48,10 +49,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, GetInput(context, node, 0)->type,
kTfLiteString);
TF_LITE_ENSURE_TYPES_EQ(context, GetOutput(context, node, 0)->type,
kTfLiteString);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
TF_LITE_ENSURE_TYPES_EQ(context, input_tensor->type, kTfLiteString);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteString);
return kTfLiteOk;
}
@ -91,7 +94,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Split sentence to words.
std::vector<StringRef> words;
tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
tflite::StringRef strref = tflite::GetString(input, 0);
int prev_idx = 0;
for (int i = 1; i < strref.len; i++) {
if (isspace(*(strref.str + i))) {

View File

@ -113,10 +113,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* begin;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Ensure validity of input tensor and its dimension.
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@ -142,10 +147,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* begin;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@ -80,8 +83,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
tflite::SpaceToDepthParams op_params; \

View File

@ -143,12 +143,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* default_value =
GetInput(context, node, kDefaultValueTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));
const TfLiteTensor* default_value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value));
// TODO(renjieliu): Handle validate_indices.
@ -178,7 +184,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, CheckDimensionsMatch(context, indices, output_shape, values));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = values->type;
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@ -191,13 +199,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <typename T, typename TI>
TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* default_value =
GetInput(context, node, kDefaultValueTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));
const TfLiteTensor* default_value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
@ -238,8 +254,12 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));
switch (values->type) {
case kTfLiteFloat32:

View File

@ -41,7 +41,9 @@ struct OpContext {
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < NumOutputs(node); ++i) {
SetTensorToDynamic(GetOutput(context, node, i));
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
SetTensorToDynamic(tensor);
}
return kTfLiteOk;
}
@ -65,7 +67,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
for (int i = 0; i < NumOutputs(node); ++i) {
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output_dims->data[axis_value] = slice_size;
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
}
@ -85,7 +88,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
input_type == kTfLiteInt32);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
tensor->type = input_type;
}
// If we know the contents of the 'axis' tensor, resize all outputs.

View File

@ -45,7 +45,9 @@ struct OpContext {
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < NumOutputs(node); ++i) {
SetTensorToDynamic(GetOutput(context, node, i));
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
SetTensorToDynamic(tensor);
}
return kTfLiteOk;
}
@ -113,7 +115,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
for (int i = 0; i < NumOutputs(node); ++i) {
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output_dims->data[axis_value] = size_splits_vector.at(i);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
}
@ -133,7 +136,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64 || input_type == kTfLiteInt8);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
tensor->type = input_type;
}
auto size_splits = op_context.size_splits;

View File

@ -60,9 +60,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -101,9 +107,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
ruy::profiler::ScopeLabel label("SquaredDifference");
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32) {
EvalSquaredDifference<float>(context, node, data, input1, input2, output);

View File

@ -217,9 +217,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -435,9 +441,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
output->type == kTfLiteInt64) {

View File

@ -82,11 +82,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* weights_feature;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
&weights_feature));
const TfLiteTensor* weights_time;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
@ -108,8 +111,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}
const TfLiteTensor* state = GetInput(context, node, kStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Check the shape of input state tensors.
TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
@ -143,7 +149,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_size_array->data[0] = batch_size;
scratch_size_array->data[1] = num_filters;
TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
// The scratch buffer is of type int32 for full integer svdf and it's of type
// float32 for hybrid and float case.
@ -161,7 +169,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Tell interpreter to allocate temporary tensors to store quantized values
// of input tensors.
node->temporaries->data[1] = scratch_tensor_index + 1;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&input_quantized));
input_quantized->type = weights_feature->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -172,7 +182,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Tell interpreter to allocate temporary tensors to store scaling factors.
node->temporaries->data[2] = scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -186,7 +198,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Used to store dequantized weights_time matrix for hybrid computation of
// matmul(state, weights_time), which occurs in floating point.
node->temporaries->data[3] = scratch_tensor_index + 3;
TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* float_weights_time;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&float_weights_time));
float_weights_time->type = kTfLiteFloat32;
// Persistent so that we can compute the dequantized weights only once.
float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
@ -199,7 +213,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[4] = scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -211,7 +227,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[5] = scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
row_sums->type = kTfLiteFloat32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_filters};
@ -228,7 +246,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_temp_size_array->data[0] = num_units;
output_temp_size_array->data[1] = batch_size;
node->temporaries->data[1] = scratch_tensor_index + 1;
TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* output_temp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
output_temp->type = kTfLiteInt32;
output_temp->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
@ -263,17 +283,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* weights_feature;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
&weights_feature));
const TfLiteTensor* weights_time;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &scratch));
TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (weights_feature->type) {
case kTfLiteFloat32: {
@ -286,14 +313,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scaling_factors =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* float_weights_time =
GetTemporary(context, node, /*index=*/3);
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&input_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
TfLiteTensor* float_weights_time;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&float_weights_time));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
&zero_points));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
// Dequantize weights time.
// TODO(alanchiao): this dequantization initialization only needs to
// happen once per model and should theoretically be placed in either
@ -322,7 +356,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input->quantization.params);
auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* output_temp;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&output_temp));
// Currently supports only ReLU.
// TODO(jianlijianli): support other activations.

View File

@ -49,9 +49,14 @@ TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
}
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
const int num_dimensions = NumDimensions(input);
const int num_multipliers = NumElements(multipliers);
@ -208,12 +213,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
// Only int32 and int64 multipliers type is supported.
if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
context->ReportError(context,
@ -231,9 +241,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));

View File

@ -35,14 +35,16 @@ constexpr int kOutputIndexes = 1;
namespace {
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
// INT32 number of top results is supported.
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
// Check that the tensor contains only one value.
TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
const int32 k = *GetTensorData<int32_t>(top_k);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int num_dimensions = NumDimensions(input);
// Check that input has one or more dimensions.
TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
@ -59,8 +61,12 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
}
output_indexes_shape->data[num_dimensions - 1] = k;
output_values_shape->data[num_dimensions - 1] = k;
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
// Force output types.
output_indexes->type = kTfLiteInt32;
output_values->type = input->type;
@ -195,19 +201,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
// Set output dynamic if the input is not const.
if (IsConstantTensor(top_k)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
} else {
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
SetTensorToDynamic(output_indexes);
SetTensorToDynamic(output_values);
}
@ -215,16 +229,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
if (IsDynamicTensor(output_values)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
}
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
const int32 k = top_k->data.i32[0];
// The tensor can have more than 2 dimensions or even be a vector, the code
// anyway calls the internal dimension as row;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int32 row_size = input->dims->data[input->dims->size - 1];
int32 num_rows = 1;
for (int i = 0; i < input->dims->size - 1; ++i) {

View File

@ -250,13 +250,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// Retrieve tensors
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &weights));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
const TfLiteTensor* bias = nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Tensor sanity checks
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@ -306,7 +313,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* col2im = nullptr;
if (data->has_col2im) {
node->temporaries->data[data->col2im_index] = data->col2im_id;
col2im = GetTemporary(context, node, user_data->col2im_index);
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, user_data->col2im_index, &col2im));
}
if (!IsConstantTensor(output_shape)) {
@ -326,8 +335,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (data->weights_are_transposed) {
node->temporaries->data[data->transposed_weights_index] =
data->transposed_weights_id;
TfLiteTensor* transposed_weights =
GetTemporary(context, node, user_data->transposed_weights_index);
TfLiteTensor* transposed_weights;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, user_data->transposed_weights_index,
&transposed_weights));
if (!IsConstantTensor(weights)) {
SetTensorToDynamic(transposed_weights);
} else {
@ -339,8 +351,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input->type == kTfLiteInt16) {
node->temporaries->data[data->scratch_tensor_index] =
data->scratch_tensor_id;
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (input->type == kTfLiteInt16) {
scratch_buffer->type = kTfLiteInt64;
} else {
@ -549,15 +563,22 @@ void EvalQuantizedPerChannel16x8(
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Retrieve tensors (All should be allocated by now)
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &weights));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
const TfLiteTensor* bias =
(NumInputs(node) == 4)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* col2im = data->has_col2im
? GetTemporary(context, node, data->col2im_index)
@ -604,8 +625,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteUInt8: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));
@ -621,8 +644,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt8: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));
@ -636,8 +661,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -88,14 +89,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
}
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
@ -110,16 +116,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_output);
}
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@ -176,18 +188,24 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
}
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
&forget_gate_bias));
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
&output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
@ -229,27 +247,33 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
kTfLiteFloat32);
}
const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
const TfLiteTensor* forget_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
lstm::full::kForgetLayerNormCoefficientsTensor,
&forget_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* cell_layer_norm_coefficients =
GetInput(context, node, lstm::full::kCellLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
const TfLiteTensor* cell_layer_norm_coefficients;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node,
lstm::full::kCellLayerNormCoefficientsTensor,
&cell_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* output_layer_norm_coefficients =
GetInput(context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
const TfLiteTensor* output_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
lstm::full::kOutputLayerNormCoefficientsTensor,
&output_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
@ -291,7 +315,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const auto* params =
@ -301,14 +327,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
const int n_input = input->dims->data[2];
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
n_cell);
@ -320,7 +352,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_cell, is_layer_norm_lstm));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
lstm::full::kOutputTensor, &output));
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
@ -351,7 +385,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_tensor_index + kScratchBuffer;
// Create a scratch buffer tensor.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
@ -376,8 +412,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -387,8 +424,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
TfLiteTensor* output_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateQuantized,
&output_state_quantized));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@ -401,8 +440,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kCellStateQuantized] =
scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, kCellStateQuantized);
TfLiteTensor* cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kCellStateQuantized,
&cell_state_quantized));
cell_state_quantized->type = input_to_output_weights->type;
cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@ -420,7 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -432,8 +476,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -444,8 +490,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -461,8 +509,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell};
@ -478,7 +528,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratch] =
scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
@ -492,7 +544,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -503,8 +557,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -514,7 +570,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_state_zp_size));
}
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8;
@ -542,25 +600,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const bool time_major = params->time_major;
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToInputWeightsTensor);
@ -571,12 +648,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
&output_gate_bias));
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
@ -584,14 +667,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TFLITE_DCHECK(output_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TFLITE_DCHECK(cell_state != nullptr);
const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm
@ -614,7 +699,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
lstm::full::kOutputLayerNormCoefficientsTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
lstm::full::kOutputTensor, &output));
// Copy out the LSTM specific params so they can be passed in the function.
TfLiteLSTMParams lstm_params;
@ -647,7 +734,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid(
input, input_to_input_weights,

View File

@ -61,13 +61,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* hidden_state =
GetInput(context, node, kHiddenStateTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
const TfLiteTensor* hidden_state;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@ -92,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
@ -112,7 +121,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -121,8 +132,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&hidden_state_quantized));
hidden_state_quantized->type = input_weights->type;
hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
@ -134,7 +146,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
hidden_state_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -145,7 +159,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -158,7 +174,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -169,7 +187,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
@ -335,15 +355,24 @@ TfLiteStatus EvalHybrid(
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
// The hidden_state is a variable input tensor that can be modified.
TfLiteTensor* hidden_state =
const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
GetVariableInput(context, node, kHiddenStateTensor);
TF_LITE_ENSURE(context, hidden_state != nullptr);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input_weights->type) {
case kTfLiteFloat32:
@ -353,12 +382,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8: {
// TODO(mirkov): implement eval with quantized inputs as well.
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &input_quantized));
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &accum_scratch));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &zero_points));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points,

View File

@ -44,11 +44,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output_unique_tensor =
GetOutput(context, node, kOutputUniqueTensor);
TfLiteTensor* output_index_tensor =
GetOutput(context, node, kOutputIndexTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output_unique_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
&output_unique_tensor));
TfLiteTensor* output_index_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
&output_index_tensor));
// The op only supports 1D input.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
@ -70,7 +73,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
// Note that we prefer to use map than unordered_map as it showed less
// increase in the binary size.
std::map<T, int> unique_values;
TfLiteTensor* output_indexes = GetOutput(context, node, 1);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
std::vector<T> output_values;
I* indexes = GetTensorData<I>(output_indexes);
const T* data = GetTensorData<T>(input);
@ -88,7 +92,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
}
}
// Allocate output tensor.
TfLiteTensor* unique_output = GetOutput(context, node, 0);
TfLiteTensor* unique_output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
shape->data[0] = unique_values.size();
@ -127,8 +132,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
} // namespace
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output_index_tensor;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, 1, &output_index_tensor));
TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
NumElements(input));

View File

@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TF_LITE_ENSURE(context, NumElements(input) > 0);
int axis = data->axis;
if (axis < 0) {
@ -67,7 +68,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]);
for (int i = 0; i < data->num; ++i) {
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
// Guarantee input/output quantization params match as we do not support
// rescaling of unpacked quantized tensors.
@ -98,7 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteUnpackParams* data =
reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
switch (input->type) {
case kTfLiteFloat32: {
UnpackImpl<float>(context, node, input, data->num, data->axis);

View File

@ -56,9 +56,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* cond_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
&cond_tensor));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (cond_tensor->type != kTfLiteBool) {
context->ReportError(context,
@ -81,9 +84,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* cond_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
&cond_tensor));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

View File

@ -195,7 +195,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}
for (int i = 0; i < num_inputs; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (op_data->body_has_dynamic_output_tensors) {
SetTensorToDynamic(output);
} else {

View File

@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input->type;
return context->ResizeTensor(context, output,
@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const int num_elements = NumElements(input);
switch (input->type) {
case kTfLiteInt64: