[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert `nullptr` checks on all usages. We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor`, but only where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it were never `nullptr`).

PiperOrigin-RevId: 332517854
Change-Id: Ic27221dd1f0fbe302f311c2fe5a846ed8ff02016
This commit is contained in:
parent ec98fee0c3
commit e11f55585f
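Below is a minimal, illustrative sketch (not part of this commit) of the migration pattern the diff applies across kernels: code that previously used the nullable getters switches to the *Safe variants so a missing tensor surfaces as a kernel error instead of a nullptr dereference. The op name MyOpPrepare and the tensor indices here are hypothetical; GetInputSafe, GetOutputSafe and TF_LITE_ENSURE_OK are the real helpers used throughout the diff below.

// Hypothetical kernel Prepare function, shown only to illustrate the pattern.
TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  // Old style: GetInput/GetOutput may now return nullptr in some cases.
  //   const TfLiteTensor* input = GetInput(context, node, 0);
  //   TfLiteTensor* output = GetOutput(context, node, 0);
  // New style: the *Safe getters return a TfLiteStatus; on failure the macro
  // returns from Prepare, so on success the pointers are never nullptr.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  // Safe to dereference both tensors from here on.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_size);
}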
@@ -42,9 +42,12 @@ TfLiteRegistration AddOpRegistration() {

reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
const TfLiteTensor* input1 = GetInput(context, node, 0);
const TfLiteTensor* input2 = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
for (int i = 0; i < input1->dims->size; ++i) {

@@ -58,13 +61,16 @@ TfLiteRegistration AddOpRegistration() {

reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
// Copy input data to output data.
const TfLiteTensor* a0 = GetInput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TF_LITE_ENSURE(context, a0);
TF_LITE_ENSURE(context, a0->data.f);
const TfLiteTensor* a1 = GetInput(context, node, 1);
const TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
TF_LITE_ENSURE(context, a1);
TF_LITE_ENSURE(context, a1->data.f);
TfLiteTensor* out = GetOutput(context, node, 0);
TfLiteTensor* out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
TF_LITE_ENSURE(context, out);
TF_LITE_ENSURE(context, out->data.f);
int num = a0->dims->data[0];

@@ -267,7 +273,8 @@ class TestDelegate : public ::testing::Test {
a0 = GetInput(context, node, 0);
a1 = a0;
}
TfLiteTensor* out = GetOutput(context, node, 0);
TfLiteTensor* out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
int num = 1;
for (int i = 0; i < a0->dims->size; ++i) {
num *= a0->dims->data[i];

@@ -289,8 +296,10 @@ class TestDelegate : public ::testing::Test {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Shapes should already by propagated by the runtime, just need to
// check.
const TfLiteTensor* input1 = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const int input_dims_size = input1->dims->size;
TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
for (int i = 0; i < input_dims_size; ++i) {

@@ -315,7 +324,8 @@ class TestDelegate : public ::testing::Test {
input1 = GetInput(context, node, 0);
input2 = input1;
}
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

TF_LITE_ENSURE_STATUS(context->ResizeTensor(
context, output, TfLiteIntArrayCopy(input1->dims)));

@@ -1169,11 +1179,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test {

reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Output 0 is dynamic
TfLiteTensor* output0 = GetOutput(context, node, 0);
TfLiteTensor* output0;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0));
SetTensorToDynamic(output0);
// Output 1 has the same shape as input.
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output1 = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(
context, output1, TfLiteIntArrayCopy(input->dims)));
return kTfLiteOk;

@@ -1193,11 +1206,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test {
// If tensors are resized, the runtime should propagate shapes
// automatically if correct flag is set. Ensure values are correct.
// Output 0 should be dynamic.
TfLiteTensor* output0 = GetOutput(context, node, 0);
TfLiteTensor* output0;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0));
TF_LITE_ENSURE(context, IsDynamicTensor(output0));
// Output 1 has the same shape as input.
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output1 = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1));
TF_LITE_ENSURE(context, input->dims->size == output1->dims->size);
TF_LITE_ENSURE(context, input->dims->data[0] == output1->dims->data[0]);
return kTfLiteOk;
@@ -1166,6 +1166,9 @@ class MulOperationParser : public TFLiteOperationParser {
}
auto input0 = tflite::GetInput(context, tflite_node, 0);
auto input1 = tflite::GetInput(context, tflite_node, 1);
if (input0 == nullptr || input1 == nullptr) {
return absl::InvalidArgumentError("At least one input tensor is null");
}
if (input0->dims->size == input1->dims->size) {
// this code checks that at least one input of Mul not smaller in all
// dimensions. Sometimes Mul used for matrix-vector multiplication that we

@@ -1380,7 +1383,10 @@ class PadOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
auto pad_tensor = tflite::GetInput(context, tflite_node, 1);
const TfLiteTensor* pad_tensor = tflite::GetInput(context, tflite_node, 1);
if (pad_tensor == nullptr) {
return absl::InvalidArgumentError("Padding tensor was null");
}
if (pad_tensor->dims->size != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor dimension: expected 2 dim, got ",
@@ -328,7 +328,9 @@ bool IsConvolutionOpSupported(const TfLiteRegistration* registration,
const int kOutputShapeTensor = 0; // Only used for TransposeConv
const int kWeightTensor = 1;
const int kBiasTensor = 2; // Only used for non-TransposeConv
const TfLiteTensor* weights = GetInput(context, node, kWeightTensor);
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightTensor, &weights));
const int max_kernel_size = 16384;
if (!IsConstantTensor(weights)) {
return false;

@@ -153,8 +153,10 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
if (fc_params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
return false;
}
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* weights = GetInput(context, node, kWeights);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeights, &weights));

if (!IsFloatType(input->type)) {
return false;

@@ -169,7 +171,8 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
}

if (node->inputs->size > 2) {
const TfLiteTensor* bias = GetInput(context, node, kBias);
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBias, &bias));
if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
return false;
}

@@ -97,7 +97,8 @@ OpBuilder* CreateMirrorPadOpBuilder(GraphBuilder* graph_builder) {
bool IsPadOpSupported(const TfLiteRegistration* registration,
const TfLiteNode* node, TfLiteContext* context) {
// padding is d x 2 tensor, where d is the dimension of input.
const TfLiteTensor* padding = GetInput(context, node, 1);
const TfLiteTensor* padding;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding));
if (!IsConstantTensor(padding)) {
TF_LITE_KERNEL_LOG(context,
"%s: Only constant padding is supported for PAD.",

@@ -126,7 +126,8 @@ bool IsReshapeOpSupported(const TfLiteRegistration* registration,
}

const int kShapeTensor = 1;
const auto* shape = GetInput(context, node, kShapeTensor);
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
if (shape->allocation_type != kTfLiteMmapRo) {
TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
return false;
@@ -62,14 +62,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The outputs should be top_paths * 3 + 1.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);

const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
const TfLiteTensor* inputs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputsTensor, &inputs));
TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
// TensorFlow only supports float.
TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
const int batch_size = SizeOfDimension(inputs, 1);

const TfLiteTensor* sequence_length =
GetInput(context, node, kSequenceLengthTensor);
const TfLiteTensor* sequence_length;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
&sequence_length));
TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
// TensorFlow only supports int32.

@@ -78,17 +81,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize decoded outputs.
// Do not resize indices & values cause we don't know the values yet.
for (int i = 0; i < top_paths; ++i) {
TfLiteTensor* indices = GetOutput(context, node, i);
TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
SetTensorToDynamic(indices);
TfLiteTensor* values = GetOutput(context, node, i + top_paths);
TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, i + top_paths, &values));
SetTensorToDynamic(values);
TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths);
TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
&output_shape));
SetTensorToDynamic(output_shape);
}

// Resize log probability outputs.
TfLiteTensor* log_probability_output =
GetOutput(context, node, top_paths * 3);
TfLiteTensor* log_probability_output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
&log_probability_output));
TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
log_probability_output_shape_array->data[0] = batch_size;
log_probability_output_shape_array->data[1] = top_paths;

@@ -127,13 +136,18 @@ TfLiteStatus StoreAllDecodedSequences(
const int32_t p_num = num_entries[p];

// Resize the decoded outputs.
TfLiteTensor* indices = GetOutput(context, node, p);
TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));

TfLiteTensor* values = GetOutput(context, node, p + top_paths);
TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, p + top_paths, &values));
TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));

TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths);
TfLiteTensor* decoded_shape;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
&decoded_shape));
TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));

int32_t max_decoded = 0;

@@ -161,9 +175,12 @@ TfLiteStatus StoreAllDecodedSequences(
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
const TfLiteTensor* sequence_length =
GetInput(context, node, kSequenceLengthTensor);
const TfLiteTensor* inputs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputsTensor, &inputs));
const TfLiteTensor* sequence_length;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
&sequence_length));
const CTCBeamSearchDecoderParams* option =
reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);

@@ -207,7 +224,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
std::vector<float> log_probs;

TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths);
TfLiteTensor* log_probabilities;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
float* log_probabilities_output = GetTensorData<float>(log_probabilities);

// Assumption: the blank index is num_classes - 1
@@ -127,44 +127,55 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

// input's dim = [n_time, n_batch, n_input]
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];

// input_state's dim = [n_batch, n_output]
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
const TfLiteTensor* input_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputState, &input_state));
TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
const int n_output = input_state->dims->data[1];

// gate_weight' dim = [2 * n_output, n_input + n_output]
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
const TfLiteTensor* gate_weight;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateWeight, &gate_weight));
TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

// gate_bias' dim = [2 * n_output]
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
const TfLiteTensor* gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateBias, &gate_bias));
TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

// candidate_weight' dim = [n_output, n_input + n_output]
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
const TfLiteTensor* candidate_weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
&candidate_weight));
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
n_input + n_output);

// candidate_bias' dim = [n_output]
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
const TfLiteTensor* candidate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

// output's dim = [n_time, n_batch, n_output]
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = n_time;
output_size->data[1] = n_batch;

@@ -173,7 +184,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, output, output_size));

// output_state's dim = [n_batch, n_output]
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* output_state;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &output_state));
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state,
TfLiteIntArrayCopy(input_state->dims)));

@@ -183,7 +196,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// activation's dim = [n_batch, 2 * n_output]
node->temporaries->data[kActivation] = *scratch_tensor_index;
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* activation;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kActivation, &activation));
activation->type = input->type;
activation->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);

@@ -194,7 +209,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// concat's dim = [n_batch, n_input + n_output]
node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
TfLiteTensor* concat;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
concat->type = input->type;
concat->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);

@@ -207,17 +223,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* input_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputState, &input_state));
const TfLiteTensor* gate_weight;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateWeight, &gate_weight));
const TfLiteTensor* gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateBias, &gate_bias));
const TfLiteTensor* candidate_weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
&candidate_weight));
const TfLiteTensor* candidate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
TfLiteTensor* output_state;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &output_state));
TfLiteTensor* activation;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kActivation, &activation));
TfLiteTensor* concat;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

if (gate_weight->type == kTfLiteFloat32) {
@@ -91,8 +91,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);

@@ -180,8 +183,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
FrontendReset(data->state);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (data->out_float) {
GenerateFeatures<float>(data, input, output);
@@ -621,8 +621,10 @@ TfLiteRegistration GetPassthroughOpRegistration() {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
auto* first_new_tensor = static_cast<int*>(node->user_data);

const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));

TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));

@@ -646,7 +648,8 @@ TfLiteRegistration GetPassthroughOpRegistration() {
return kTfLiteOk;
};
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));

auto populate = [&](int id) {
TfLiteTensor* t = &context->tensors[id];

@@ -780,8 +783,10 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
// String-in String-out node.
TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
DynamicBuffer buf;
StringRef str_ref = GetString(input, 0);
buf.AddString(str_ref);

@@ -792,14 +797,17 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
// String-in Int-out node.
TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
return context->ResizeTensor(context, output, outputSize);
};
reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
a1->data.i32[0] = a0->bytes;
return kTfLiteOk;
};

@@ -848,14 +856,18 @@ TEST(BasicInterpreter, AllocateTwice) {

TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));
TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
return context->ResizeTensor(context, tensor1, newSize);
};
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
int num = a0->dims->data[0];
for (int i = 0; i < num; i++) {
a1->data.f[i] = a0->data.f[i];

@@ -1205,8 +1217,10 @@ class TestExecutionPlan : public ::testing::Test {

reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));
TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
return context->ResizeTensor(context, tensor1, newSize);
};

@@ -1215,8 +1229,10 @@ class TestExecutionPlan : public ::testing::Test {
CallReporting* call_reporting =
static_cast<CallReporting*>(node->builtin_data);
// Copy input data to output data.
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
int num = a0->dims->data[0];
for (int i = 0; i < num; i++) {
a1->data.f[i] = a0->data.f[i];

@@ -1403,8 +1419,10 @@ class CancellationTest : public ::testing::Test {
// Set output size to the input size in CancelOp::Prepare(). Code exists to
// have a framework in Prepare. The input and output tensors are not used.
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* in_tensor = GetInput(context, node, 0);
TfLiteTensor* out_tensor = GetOutput(context, node, 0);
const TfLiteTensor* in_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &in_tensor));
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
return context->ResizeTensor(context, out_tensor, new_size);
};

@@ -1423,8 +1441,10 @@ class CancellationTest : public ::testing::Test {
// Set output size to the input size in OkOp::Prepare(). Code exists to have
// a framework in Prepare. The input and output tensors are not used.
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* in_tensor = GetInput(context, node, 0);
TfLiteTensor* out_tensor = GetOutput(context, node, 0);
const TfLiteTensor* in_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &in_tensor));
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
return context->ResizeTensor(context, out_tensor, new_size);
};
@@ -33,15 +33,21 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
.free = nullptr,
.prepare =
[](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
tflite::GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
tflite::GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output->type = kTfLiteFloat32;
return context->ResizeTensor(context, output, output_dims);
},
.invoke =
[](TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
tflite::GetOutputSafe(context, node, 0, &output));
std::fill(output->data.f,
output->data.f + tflite::NumElements(output), 7.0f);
return kTfLiteOk;
@@ -80,9 +80,9 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
(params->key_dtype == kTfLiteString &&
params->value_dtype == kTfLiteInt64));

TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
TF_LITE_ENSURE(context, resource_handle_tensor != nullptr);
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
TF_LITE_ENSURE_EQ(context, resource_handle_tensor->type, kTfLiteInt32);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;

@@ -97,8 +97,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
// The resource id is generated based on the given table name.
const int resource_id = std::hash<std::string>{}(params->table_name);

TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
auto* resource_handle_data =
GetTensorData<std::int32_t>(resource_handle_tensor);
resource_handle_data[0] = resource_id;

@@ -34,17 +34,23 @@ TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);

const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
const TfLiteTensor* default_value_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value_tensor));

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
output_tensor->type == kTfLiteString) ||

@@ -55,14 +61,19 @@ TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
TfLiteTensor* output_tensor = GetOutput(context, node, 0);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* default_value_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value_tensor));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();

@@ -33,14 +33,19 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* value_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueTensor, &value_tensor));
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
value_tensor->type == kTfLiteString) ||
(key_tensor->type == kTfLiteString &&

@@ -52,12 +57,17 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
const int resource_id = input_resource_id_tensor->data.i32[0];

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* value_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueTensor, &value_tensor));

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();

@@ -32,14 +32,16 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);

TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;

@@ -47,11 +49,14 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];

TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
auto* output_data = GetTensorData<std::int64_t>(output_tensor);

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
@@ -30,19 +30,70 @@ namespace tflite {

namespace {

inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->inputs->size) {
const int tensor_index = node->inputs->data[index];
// Assumes tensor_index is a valid index (in bounds)
inline TfLiteTensor* GetTensorAtIndex(const TfLiteContext* context,
int tensor_index) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}

// Validate in a single place to reduce binary size
inline TfLiteStatus ValidateTensorIndexingSafe(const TfLiteContext* context,
int index, int max_size,
const int* tensor_indices,
int* tensor_index) {
if (index < 0 || index >= max_size) {
TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
"Invalid tensor index %d (not in [0, %d))\n", index,
max_size);
return kTfLiteError;
}
if (tensor_indices[index] == kTfLiteOptionalTensor) {
TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
"Tensor at index %d was optional but was expected\n",
index);
return kTfLiteError;
}

*tensor_index = tensor_indices[index];
return kTfLiteOk;
}

// Same as above but returns -1 for invalid inputs instead of status + logging
// error.
inline int ValidateTensorIndexing(const TfLiteContext* context, int index,
int max_size, const int* tensor_indices) {
if (index >= 0 && index < max_size) {
const int tensor_index = tensor_indices[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
return tensor_index;
}
}
return nullptr;
return -1;
}

inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->inputs->size, node->inputs->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}

inline TfLiteStatus GetMutableInputSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
const TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(
context, ValidateTensorIndexingSafe(context, index, node->inputs->size,
node->inputs->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}

} // anonymous namespace.

@@ -52,6 +103,11 @@ const TfLiteTensor* GetInput(const TfLiteContext* context,
return GetMutableInput(context, node, index);
}

TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, const TfLiteTensor** tensor) {
return GetMutableInputSafe(context, node, index, tensor);
}

TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
int index) {
TfLiteTensor* tensor = GetMutableInput(context, node, index);

@@ -60,17 +116,22 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,

TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
if (index >= 0 && index < node->outputs->size) {
const int tensor_index = node->outputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
const int tensor_index = ValidateTensorIndexing(
context, index, node->outputs->size, node->outputs->data);
if (tensor_index < 0) {
return nullptr;
}
return nullptr;
return GetTensorAtIndex(context, tensor_index);
}

TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(
context, ValidateTensorIndexingSafe(context, index, node->outputs->size,
node->outputs->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}

const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,

@@ -78,6 +139,50 @@ const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
return GetInput(context, node, index);
}

#ifndef TF_LITE_STATIC_MEMORY
TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->temporaries->size, node->temporaries->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}

TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
context, index, node->temporaries->size,
node->temporaries->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}

const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->intermediates->size, node->intermediates->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}

TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
context, index, node->intermediates->size,
node->intermediates->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}
#endif // TF_LITE_STATIC_MEMORY

// Per-axis
TfLiteStatus PopulateConvolutionQuantizationParams(
TfLiteContext* context, const TfLiteTensor* input,
@@ -40,6 +40,17 @@ namespace tflite {
const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index);

// Same as `GetInput` but returns boolean and uses output argument for tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetInputSafe(context, node, kMyTensorIdx, &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, const TfLiteTensor** tensor);

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);

@@ -60,6 +71,17 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index);

// Same as `GetOutput` but returns boolean and uses output argument for tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetOutputSafe(context, node, kMyTensorIdx, &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, TfLiteTensor** tensor);

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);

@@ -72,11 +94,6 @@ TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
const TfLiteNode* node, int index);

inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}

#ifndef TF_LITE_STATIC_MEMORY
// Note: You must check if result is not null:
//

@@ -85,18 +102,22 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->temporaries->size) {
const int tensor_index = node->temporaries->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}
TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
int index);

// Same as `GetTemporary` but returns boolean and uses output argument for
// tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetTemporarySafe(context, node, kMyTensorIdx,
// &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor);

// Note: You must check if result is not null:
//

@@ -105,25 +126,37 @@ inline TfLiteTensor* GetTemporary(TfLiteContext* context,
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->intermediates->size) {
const int tensor_index = node->intermediates->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index);

// Same as `GetIntermediates` but returns boolean and uses output argument for
// tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetIntermediatesSafe(context, node, kMyTensorIdx,
// &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor);
#endif // TF_LITE_STATIC_MEMORY

inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}

inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }

#ifndef TF_LITE_STATIC_MEMORY
inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}
#endif // TF_LITE_STATIC_MEMORY
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }

inline int64_t NumElements(const TfLiteIntArray* dims) {
int64_t count = 1;
@@ -36,10 +36,14 @@ namespace {
const char* kOpName = "SimpleOpEval";

TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/0, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/1, &input2));

TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, /*index=*/0, &output));

int32_t* output_data = output->data.i32;
*output_data = *(input1->data.i32) + *(input2->data.i32);
@@ -466,26 +466,51 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);

const TfLiteTensor* input =
GetInput(context, node, ops::builtin::lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kInputTensor, &input));

const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);

@@ -509,12 +534,21 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,

const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kOutputGateBiasTensor,
&output_gate_bias));

const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);

@@ -522,7 +556,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));

TfLiteTensor* output_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kOutputStateTensor);

@@ -531,8 +567,10 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
context, node, ops::builtin::lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);

TfLiteTensor* output =
GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node,
ops::builtin::lstm::full::kOutputTensor, &output));

std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
for (int i = 0; i < node->intermediates->size; ++i) {