[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`,
`tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr`
in some cases. Hence, we insert `nullptr` checks on all usages. We also
insert `nullptr` checks on usages of `tflite::GetVariableInput` and
`tflite::GetOptionalInputTensor`, but only where there is no obvious
handling of `nullptr` (that is, we only insert the check for the output of
these two functions if the tensor is accessed as if it were always
non-`nullptr`).

PiperOrigin-RevId: 332520146
Change-Id: I405d986cfc653aaafcfdf4162c0acbd46220b921
commit fff2c83262
parent e84a6501eb
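The pattern applied throughout the diff below is summarized by this minimal
sketch (a hypothetical unary micro-kernel `Prepare`; the enclosing function
and the `kInputTensor`/`kOutputTensor` indices are illustrative, not taken
from any one file in this change):

    // Sketch: after the refactoring, GetInput/GetOutput may return nullptr,
    // so each returned tensor is guarded before it is dereferenced.
    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      const TfLiteTensor* input = GetInput(context, node, kInputTensor);
      TF_LITE_ENSURE(context, input != nullptr);   // check added by this change
      TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
      TF_LITE_ENSURE(context, output != nullptr);  // check added by this change
      // Dereferencing is safe only after the checks above.
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
      return kTfLiteOk;
    }

TF_LITE_ENSURE returns kTfLiteError from the enclosing function when its
condition fails, so a missing tensor now surfaces as a kernel preparation
error rather than a null dereference.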
@@ -139,7 +139,9 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   if (input->type == kTfLiteInt8) {
     CalculateReluOpData<int8_t>(input, output, data);
@@ -200,6 +202,7 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
   Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
 
   if (input->type == kTfLiteInt8) {
     data->six_int8 = FloatToAsymmetricQuantizedInt8(6.0f, input->params.scale,
@@ -201,8 +201,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TF_LITE_ENSURE(context, input1 != nullptr);
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TF_LITE_ENSURE(context, input2 != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   OpData* data = static_cast<OpData*>(node->user_data);
   auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
@@ -30,7 +30,9 @@ constexpr int kOutputTensor = 0;
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@@ -77,7 +77,9 @@ void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, input != nullptr);
   TF_LITE_ENSURE(context, output != nullptr);
@@ -619,7 +619,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
 
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TF_LITE_ENSURE(context, input1 != nullptr);
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TF_LITE_ENSURE(context, input2 != nullptr);
 
   if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
     auto input1_offset = -input1->params.zero_point;
@@ -136,8 +136,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteConcatenationParams* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
 
-  TfLiteType input_type = GetInput(context, node, 0)->type;
-  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;
+  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input_tensor != nullptr);
+  TfLiteType input_type = input_tensor->type;
+  const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output_tensor != nullptr);
+  TfLiteType output_type = output_tensor->type;
 
   // Check activation and input type
   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -156,6 +160,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Shapes with dimensions >4 are not yet supported with static allocation.
   for (int i = 0; i < num_inputs; ++i) {
     const TfLiteTensor* input = GetInput(context, node, i);
+    TF_LITE_ENSURE(context, input != nullptr);
     int num_dimensions = NumDimensions(input);
 
     if (num_dimensions > 4) {
@@ -173,6 +178,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
 
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   switch (output_type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
@@ -199,6 +205,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Store input scale and zero point values in OpParams:
   for (int i = 0; i < node->inputs->size; ++i) {
     const TfLiteTensor* t = GetInput(context, node, i);
+    TF_LITE_ENSURE(context, t != nullptr);
     input_scales[i] = t->params.scale;
     input_zero_points[i] = t->params.zero_point;
   }
@@ -220,7 +227,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;
+  const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output_tensor != nullptr);
+  TfLiteType output_type = output_tensor->type;
 
   switch (output_type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
@@ -97,10 +97,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
   // parameters set. This is usually done during quantized training.
   if (data_type != kTfLiteFloat32) {
     const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+    TF_LITE_ENSURE(context, input != nullptr);
     const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+    TF_LITE_ENSURE(context, filter != nullptr);
     const TfLiteTensor* bias =
         GetOptionalInputTensor(context, node, kBiasTensor);
     TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+    TF_LITE_ENSURE(context, output != nullptr);
     int output_channels = filter->dims->data[kConvQuantizedDimension];
 
     TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
@@ -127,8 +130,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
 
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
 
   int input_width = input->dims->data[2];
   int input_height = input->dims->data[1];
@@ -82,10 +82,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
   // parameters set. This is usually done during quantized training.
   if (data_type != kTfLiteFloat32) {
     const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+    TF_LITE_ENSURE(context, input != nullptr);
     const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+    TF_LITE_ENSURE(context, filter != nullptr);
     const TfLiteTensor* bias =
         GetOptionalInputTensor(context, node, kBiasTensor);
     TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+    TF_LITE_ENSURE(context, output != nullptr);
     int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
 
     return tflite::PopulateConvolutionQuantizationParams(
@@ -114,8 +117,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
 
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
 
   const TfLiteType data_type = input->type;
   int width = SizeOfDimension(input, 2);
@@ -52,7 +52,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   // TODO(b/140515557): Add cached dequant to improve hybrid model performance.
   const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                               input->type == kTfLiteInt8 ||
@@ -41,7 +41,9 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (!IsSupportedType(input->type)) {
     TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
@@ -93,9 +93,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
@@ -45,7 +45,9 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     HardSwishParams* params = static_cast<HardSwishParams*>(node->user_data);
@@ -50,7 +50,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
 
@@ -43,7 +43,9 @@ struct OpData {
 TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
                                        OpData* data) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (input->type == kTfLiteInt8) {
@@ -51,8 +51,11 @@ struct OpData {
 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                              TfLiteMulParams* params, OpData* data) {
   const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
+  TF_LITE_ENSURE(context, input1 != nullptr);
   const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
+  TF_LITE_ENSURE(context, input2 != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -50,10 +50,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
   const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
+  TF_LITE_ENSURE(context, paddings != nullptr);
   const TfLiteTensor* constant_values =
       NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
   TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
 
@@ -222,7 +222,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));
 
@@ -95,8 +95,11 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
   PreluParams* params = static_cast<PreluParams*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* alpha = GetInput(context, node, 1);
+  TF_LITE_ENSURE(context, alpha != nullptr);
   TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   return CalculatePreluParams(input, alpha, output, params);
 }
@@ -50,7 +50,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
   // Currently this only support affine per-layer quantization.
@@ -64,6 +64,7 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
 
   // Validate axis type
   const TfLiteTensor* axis = GetInput(context, node, 1);
+  TF_LITE_ENSURE(context, axis != nullptr);
   TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
 
   if (input->type == kTfLiteInt8) {
@@ -32,7 +32,9 @@ constexpr int kOutputTensor = 0;
 
 TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   // Tensorflow's Reshape allows one of the shape components to have the
   // special -1 value, meaning it will be calculated automatically based on the
   // input. Here we calculate what that dimension should be so that the number
@@ -30,7 +30,9 @@ constexpr int kOutputTensor = 0;
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@@ -119,9 +119,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, input != nullptr);
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
 
   TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TFLITE_DCHECK(node->user_data != nullptr);
   SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);
@@ -69,6 +69,7 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* axis = GetInput(context, node, 0);
+  TF_LITE_ENSURE(context, axis != nullptr);
 
   // Dynamic output tensors are needed if axis tensor is not constant.
   // But Micro doesn't support dynamic memory allocation, so we only support
@@ -108,8 +108,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
 
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TF_LITE_ENSURE(context, input1 != nullptr);
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TF_LITE_ENSURE(context, input2 != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_STATUS(
       CalculateOpData(context, params, input1, input2, output, data));
@@ -366,13 +366,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // [4] = Activation State (variable),
   //       {2, batch_size, memory_size * num_filters}
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   const TfLiteTensor* weights_feature =
       GetInput(context, node, kWeightsFeatureTensor);
+  TF_LITE_ENSURE(context, weights_feature != nullptr);
   const TfLiteTensor* weights_time =
       GetInput(context, node, kWeightsTimeTensor);
+  TF_LITE_ENSURE(context, weights_time != nullptr);
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   const TfLiteTensor* activation_state =
       GetInput(context, node, kInputActivationStateTensor);
+  TF_LITE_ENSURE(context, activation_state != nullptr);
 
   // Define input constants based on input tensor definition above:
   const int rank = params->rank;
@@ -392,6 +396,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // [0] = float/int8_t, {2, batch_size, num_units}
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
   TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
   TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
@@ -51,7 +51,9 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
 
@@ -76,6 +78,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   data->input_zero_point = input->params.zero_point;
   return CalculateArithmeticOpData(context, node, data);
 }
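Note the commit message's carve-out for tflite::GetVariableInput and
tflite::GetOptionalInputTensor: as in the conv, depthwise_conv,
fully_connected, and svdf hunks above, an optional bias obtained via
GetOptionalInputTensor is left unchecked, because nullptr is a legitimate
result there. A minimal sketch of that acceptable-nullptr pattern (the
kBiasTensor index and the bias handling are illustrative, not from any one
file in this change):

    // Sketch: an optional bias may legitimately be nullptr, so the result
    // is branched on rather than rejected with TF_LITE_ENSURE.
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    const float* bias_data =
        (bias != nullptr) ? GetTensorData<float>(bias) : nullptr;
    // Downstream code must tolerate bias_data == nullptr.

Only when such a tensor is dereferenced unconditionally does this change
add a TF_LITE_ENSURE check for it.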