Audit and improve TfLiteType checks in kernels

PiperOrigin-RevId: 316720436
Change-Id: I2032e799ee6afa533b932385c2a70f7621f4ac1b
This commit is contained in:
Sachin Joglekar 2020-06-16 11:20:59 -07:00 committed by TensorFlower Gardener
parent ed557008d6
commit 430b00361b
75 changed files with 258 additions and 248 deletions

View File

@ -205,6 +205,7 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
// the current function, while also reporting the location of the error.
// `a` and `b` may be evaluated more than once, so no side effects or
// extremely expensive computations should be done.
// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
#define TF_LITE_ENSURE_EQ(context, a, b) \
do { \
if ((a) != (b)) { \

View File

@ -254,7 +254,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
@ -274,7 +274,7 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
double real_multiplier = input->params.scale / output->params.scale;
@ -355,7 +355,7 @@ TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
@ -384,7 +384,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (kernel_type == kFixedPointOptimized) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
@ -469,7 +469,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (kernel_type == kFixedPointOptimized) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
@ -569,7 +569,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
}
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
@ -632,7 +632,7 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
@ -671,7 +671,7 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* alpha = GetInput(context, node, 1);
PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, input->type, alpha->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
output->type = input->type;

View File

@ -90,7 +90,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
const bool requires_broadcast = !HaveSameShapes(input1, input2);

View File

@ -41,7 +41,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
TF_LITE_ENSURE_EQ(context, input1->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
}
// Use the first input node's dimension to be the dimension of the output

View File

@ -81,8 +81,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
params->stride));

View File

@ -79,8 +79,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
recurrent_weights->type);
TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
@ -288,8 +289,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
accum_scratch, row_sums, &op_data->compute_row_sums);
}
default:
context->ReportError(context, "Type %d not currently supported.",
input_weights->type);
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input_weights->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -282,7 +282,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, lhs_data->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 ||
rhs_data->type == kTfLiteInt8);
// Support dimensions between 2 and 4, inclusive.

View File

@ -203,7 +203,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
input_to_forget_weights->type);
}
@ -212,7 +212,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* input_to_output_weights =
@ -220,7 +220,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_input_weights =
@ -231,7 +231,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
input_to_forget_weights->type);
}
@ -242,7 +242,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_cell_weights =
@ -251,7 +251,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
input_to_forget_weights->type);
// We make sure the input-gate's parameters are either both present (regular
@ -268,7 +268,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
if (cell_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
input_to_forget_weights->type);
}
@ -277,7 +277,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
if (cell_to_forget_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
input_to_forget_weights->type);
}
@ -286,7 +286,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
if (cell_to_output_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
input_to_forget_weights->type);
}
@ -309,14 +309,14 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
} else {
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, forget_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* cell_bias =
GetInput(context, node, cell_gate_bias_tensor);
@ -328,7 +328,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
GetInput(context, node, output_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, output_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, projection_weights_tensor);
@ -336,7 +336,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
TF_LITE_ENSURE_EQ(context, projection_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
input_to_forget_weights->type);
}
@ -345,7 +345,7 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
}
// Making sure the projection tensors are consistent:
@ -410,7 +410,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const bool time_major = params->time_major;
const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
@ -1140,8 +1140,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
default:
context->ReportError(context, "Type %d is not currently supported.",
fw_input_to_output_weights->type);
TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
TfLiteTypeGetName(fw_input_to_output_weights->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -129,7 +129,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check all the parameters of tensor match within themselves and match the
// input configuration.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const bool time_major = params->time_major;

View File

@ -32,7 +32,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
output->type = input->type;
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_size);

View File

@ -81,7 +81,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, output->type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
if (input_type == kTfLiteInt8) {
// Make sure there is no re-scaling needed for Int8 quantized kernel. This

View File

@ -320,7 +320,7 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteInt8 || input_type == kTfLiteInt16);
TF_LITE_ENSURE_EQ(context, output->type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
const TfLiteTensor* bias = nullptr;
@ -331,15 +331,15 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
if (has_bias) {
bias = GetInput(context, node, 2);
if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else if (input_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
} else {
TF_LITE_ENSURE_EQ(context, bias->type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
@ -984,7 +984,7 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
context, node, params, data, input, filter, bias, output, im2col);
break;
default:
context->ReportError(context, "Type %s currently not supported.",
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
@ -1005,8 +1005,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
return EvalImpl<kernel_type, kTfLiteInt16>(context, node);
default:
context->ReportError(context, "Type %d not currently supported.",
input->type);
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

View File

@ -55,7 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
data_type == kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
const int block_size = params->block_size;
const int input_height = input->dims->data[1];

View File

@ -122,7 +122,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
data_type == kTfLiteInt8 || data_type == kTfLiteInt16);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, data_type);
if (!is_hybrid) {
TF_LITE_ENSURE(context,
filter->type == data_type || data_type == kTfLiteInt16);
@ -134,15 +134,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (hasBias) {
bias = GetInput(context, node, kBiasTensor);
if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else if (data_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
} else {
TF_LITE_ENSURE_EQ(context, bias->type, data_type);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, data_type);
}
TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3),
@ -520,9 +520,9 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
return EvalHybridPerChannel<kernel_type>(context, node, params, data,
input, filter, bias, output);
} else {
context->ReportError(
context, "Type %d with filter type %d not currently supported.",
input->type, filter->type);
TF_LITE_KERNEL_LOG(
context, "Type %s with filter type %s not currently supported.",
TfLiteTypeGetName(input->type), TfLiteTypeGetName(filter->type));
return kTfLiteError;
}
break;

View File

@ -78,7 +78,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);

View File

@ -45,7 +45,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
context->ReportError(context, "Current data type %d is not supported.",
input->type);
@ -60,7 +60,7 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, expected_type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);

View File

@ -109,7 +109,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Mark the output as a dynamic tensor.
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
output->allocation_type = kTfLiteDynamic;
return kTfLiteOk;

View File

@ -39,7 +39,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
output->type = input->type;
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_size);

View File

@ -68,7 +68,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
const TfLiteType type = input1->type;
switch (type) {

View File

@ -101,13 +101,13 @@ inline TfLiteStatus CheckTypes(TfLiteContext* context,
if (is_quantized) {
if (is_shuffled) {
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteUInt8);
TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteUInt8);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt8);
TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteUInt8);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
} else if (is_hybrid) {
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
} else {
TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
@ -120,9 +120,9 @@ inline TfLiteStatus CheckTypes(TfLiteContext* context,
}
} else {
// Only float32 is supported currently
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
}

View File

@ -88,7 +88,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input->dims->data + input->dims->size);
subgraph->ResizeInputTensor(i, dims);
TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
TF_LITE_ENSURE_EQ(context, input->type, subgraph_input->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
}
// Note: The `Prepare` function is responsible to run `AllocateTensors` on
// both subgraphs. It's intentionally not to break out of the loop when

View File

@ -52,7 +52,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
output->type == kTfLiteUInt8 ||
output->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
@ -133,8 +133,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
depth, GetTensorData<int8>(input),
GetTensorData<int8>(output));
} else {
context->ReportError(context, "Output type is %d, requires float.",
output->type);
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@ -44,8 +44,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = input->dims->data[0];

View File

@ -58,7 +58,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
const TfLiteType type = input1->type;
if (type != kTfLiteBool) {

View File

@ -762,7 +762,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
input_to_forget_weights->type);
}
@ -771,7 +771,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_input_weights =
@ -782,7 +782,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
input_to_forget_weights->type);
}
@ -793,7 +793,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_cell_weights =
@ -802,7 +802,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
n_output);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
input_to_forget_weights->type);
// We make sure the input-gate's parameters are either both present (regular
@ -819,7 +819,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_input_weights) {
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_input_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
@ -829,7 +829,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_forget_weights) {
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_forget_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
@ -839,7 +839,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_output_weights) {
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_output_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
@ -863,9 +863,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, input_gate_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_EQ(context, input_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
}
}
@ -874,18 +874,18 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, forget_gate_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, cell_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, cell_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_EQ(context, cell_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, cell_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* output_gate_bias =
@ -893,9 +893,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, output_gate_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_EQ(context, output_gate_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* projection_weights =
@ -904,7 +904,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
TF_LITE_ENSURE_EQ(context, projection_weights->type,
TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
input_to_forget_weights->type);
}
@ -914,9 +914,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, projection_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_EQ(context, projection_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
}
}
@ -940,10 +940,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
kTfLiteFloat32);
}
}
@ -955,10 +955,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);
}
@ -969,10 +969,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);
}
@ -983,10 +983,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
kTfLiteFloat32);
}
}

View File

@ -57,7 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
OpContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input1->type,
op_context.input2->type);
op_context.output->type = op_context.input1->type;
bool requires_broadcast =

View File

@ -80,9 +80,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_wav->type, output->type);
TF_LITE_ENSURE_EQ(context, input_rate->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input_wav->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input_rate->type, kTfLiteInt32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = input_wav->dims->data[0];

View File

@ -79,7 +79,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
const bool requires_broadcast = !HaveSameShapes(input1, input2);

View File

@ -136,8 +136,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
op_context.output->type = op_context.dtype;
break;
default:
context->ReportError(context, "Unknown output data type: %d",
op_context.dtype);
TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
TfLiteTypeGetName(op_context.dtype));
return kTfLiteError;
}
@ -148,8 +148,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
op_context.dtype);
if (!IsConstantTensor(op_context.depth)) {
SetTensorToDynamic(op_context.output);

View File

@ -57,7 +57,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 1; i < data->values_count; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
TF_LITE_ENSURE_EQ(context, input0->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
}
// Resize output. rank R will become rank R + 1
@ -73,7 +73,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, output->type, input0->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);
// Guarantee input/output quantization params match as we do not support
// packing quantized tensors.

View File

@ -111,9 +111,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
PadContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.output->type);
if (op_context.constant_values != nullptr) {
TF_LITE_ENSURE_EQ(context, op_context.input->type,
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.constant_values->type);
}
@ -268,9 +269,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
} break;
default:
context->ReportError(context,
"Type %d is currently not supported by Pad.",
op_context.input->type);
TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by Pad.",
TfLiteTypeGetName(op_context.input->type));
return kTfLiteError;
}
#undef TF_LITE_PAD

View File

@ -74,7 +74,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
int batches = input->dims->data[0];
int height = input->dims->data[1];
@ -98,7 +98,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
}
if (pool_type == kL2) {
// We currently don't have a quantized implementation of L2Pool
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
}
}
@ -387,8 +387,8 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
output);
break;
default:
context->ReportError(context, "Type %d not currently supported.",
input->type);
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
@ -418,8 +418,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
output);
break;
default:
context->ReportError(context, "Type %d not currently supported.",
input->type);
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -58,11 +58,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
const TfLiteType type = input1->type;
if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
context->ReportError(context, "Unsupported data type %d.", type);
TF_LITE_KERNEL_LOG(context, "Unsupported data type %s.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
output->type = type;

View File

@ -100,8 +100,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, limit->type, dtype);
TF_LITE_ENSURE_EQ(context, delta->type, dtype);
TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = dtype;

View File

@ -58,7 +58,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* variable_tensor = variable->GetTensor();
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TF_LITE_ENSURE_EQ(context, variable_tensor->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, variable_tensor->type, output->type);
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, output, TfLiteIntArrayCopy(variable_tensor->dims)));

View File

@ -235,7 +235,7 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteBool);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
return PrepareSimple(context, node);
}

View File

@ -68,7 +68,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// and the size being 1D tensor with exactly 2 elements.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, size->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, size->dims->data[0], 2);
output->type = input->type;
@ -122,9 +122,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(size), GetTensorData<int32>(size),
GetTensorShape(output), GetTensorData<int8_t>(output));
} else {
context->ReportError(context,
"Output type is %d, requires float, uint8 or int8.",
output->type);
TF_LITE_KERNEL_LOG(context,
"Output type is %s, requires float, uint8 or int8.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@ -61,7 +61,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
return context->ResizeTensor(context, output, output_shape);
}

View File

@ -58,7 +58,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
return context->ResizeTensor(context, output, output_shape);
}

View File

@ -34,7 +34,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
output->type = input->type;
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_size);

View File

@ -66,8 +66,8 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Input must be bool.
TF_LITE_ENSURE(context, input_condition->type == kTfLiteBool);
TF_LITE_ENSURE_EQ(context, input_x->type, input_y->type);
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
TF_LITE_ENSURE_TYPES_EQ(context, input_x->type, input_y->type);
output->type = input_x->type;
bool same_shape = HaveSameShapes(input_condition, input_x) &&

View File

@ -48,8 +48,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString);
TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString);
TF_LITE_ENSURE_TYPES_EQ(context, GetInput(context, node, 0)->type,
kTfLiteString);
TF_LITE_ENSURE_TYPES_EQ(context, GetOutput(context, node, 0)->type,
kTfLiteString);
return kTfLiteOk;
}

View File

@ -100,7 +100,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
NumDimensions(op_context.input) >= kInputMinDimensionNum);
TF_LITE_ENSURE(context,
NumDimensions(op_context.input) <= kInputMaxDimensionNum);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.output->type);
if (!IsConstantTensor(op_context.block_shape) ||
!IsConstantTensor(op_context.paddings)) {

View File

@ -55,7 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
data_type == kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
const int block_size = params->block_size;
const int input_height = input->dims->data[1];

View File

@ -172,7 +172,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
values->type == kTfLiteInt8 ||
values->type == kTfLiteUInt8 ||
values->type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, values->type, default_value->type);
TF_LITE_ENSURE_TYPES_EQ(context, values->type, default_value->type);
// Ensure dimensions match.
TF_LITE_ENSURE_OK(
@ -229,10 +229,10 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
return SparseToDenseImpl<T, int64_t>(context, node);
}
default:
context->ReportError(
TF_LITE_KERNEL_LOG(
context,
"Indice type %d is currently not supported by sparse to dense.",
indices->type);
"Indice type %s is currently not supported by sparse to dense.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}
@ -253,10 +253,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
return EvalForIndexType<uint8_t>(context, node, indices);
default:
context->ReportError(
TF_LITE_KERNEL_LOG(
context,
"Value type %d is currently not supported by sparse to dense.",
values->type);
"Value type %s is currently not supported by sparse to dense.",
TfLiteTypeGetName(values->type));
return kTfLiteError;
}
}

View File

@ -64,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);

View File

@ -145,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
// Only INT32 begin/end/strides are supported
// TODO(soroosh) add support for INT64
TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.begin->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.end->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.strides->type, kTfLiteInt32);
TF_LITE_ENSURE_MSG(context, op_context.dims <= 5,
"StridedSlice op only supports 1D-5D input arrays.");
@ -223,10 +223,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
break;
default:
context->ReportError(context,
"Type %d is currently not supported "
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported "
"by StridedSlice.",
op_context.input->type);
TfLiteTypeGetName(op_context.input->type));
return kTfLiteError;
}
#undef TF_LITE_STRIDED_SLICE

View File

@ -206,7 +206,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
@ -287,8 +287,8 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
input2, requires_broadcast, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "output type %d is not supported.",
output->type);
TF_LITE_KERNEL_LOG(context, "output type %s is not supported.",
TfLiteTypeGetName(output->type));
}
}

View File

@ -211,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
// Only int32 and int64 multipliers type is supported.

View File

@ -37,7 +37,7 @@ namespace {
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
// INT32 number of top results is supported.
TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
// Check that the tensor contains only one value.
TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
const int32 k = *GetTensorData<int32_t>(top_k);
@ -197,10 +197,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TF_LITE_ENSURE_EQ(context, input->type, output_values->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
// Set output dynamic if the input is not const.
if (IsConstantTensor(top_k)) {
@ -252,9 +252,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output_values->data.i64);
break;
default:
context->ReportError(context,
"Type %d is currently not supported by TopK.",
output_values->type);
TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK.",
TfLiteTypeGetName(output_values->type));
return kTfLiteError;
}

View File

@ -77,7 +77,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Ensure validity of input tensor.
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
"Transpose op only supports 1D-5D input arrays.");
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.output->type);
if (!IsConstantTensor(op_context.perm)) {
SetTensorToDynamic(op_context.output);
@ -144,9 +145,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
break;
default:
context->ReportError(context,
"Type %d is currently not supported by Transpose.",
op_context.input->type);
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported by Transpose.",
TfLiteTypeGetName(op_context.input->type));
return kTfLiteError;
}
#undef TF_LITE_TRANSPOSE

View File

@ -111,8 +111,8 @@ TfLiteStatus ResizeTensor(TfLiteContext* context,
TfLiteTensor* tensor_to_resize) {
// Currently only support int32 for output shape.
if (shape_tensor->type != kTfLiteInt32) {
context->ReportError(context, "Output shape is %d, not int32.",
shape_tensor->type);
TF_LITE_KERNEL_LOG(context, "Output shape is %s, not int32.",
TfLiteTypeGetName(shape_tensor->type));
return kTfLiteError;
}
@ -176,8 +176,8 @@ TfLiteStatus ResizeCol2ImTensor(TfLiteContext* context,
const TfLiteTensor* input,
TfLiteTensor* col2im) {
if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "col2im shape is %d, not int32.",
output_shape->type);
TF_LITE_KERNEL_LOG(context, "col2im shape is %s, not int32.",
TfLiteTypeGetName(output_shape->type));
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
@ -274,7 +274,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bias = GetOptionalInputTensor(context, node, kBiasTensor);
if (bias) {
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
}
@ -282,7 +282,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else {
TF_LITE_ENSURE_EQ(context, bias->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input->type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias),
SizeOfDimension(weights, 0));
@ -294,9 +294,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
} else {
TF_LITE_ENSURE_EQ(context, weights->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, weights->type, input->type);
}
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),

View File

@ -223,7 +223,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
kTfLiteFloat32);
}
@ -233,7 +233,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* cell_layer_norm_coefficients =
@ -242,7 +242,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* output_layer_norm_coefficients =
@ -251,7 +251,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type,
TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
kTfLiteFloat32);
}
@ -290,7 +290,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const auto* params =
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
@ -659,8 +659,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
CpuBackendContext::GetFromContext(context));
}
default:
context->ReportError(context, "Type %d is not currently supported.",
input_to_output_weights->type);
TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
TfLiteTypeGetName(input_to_output_weights->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -85,8 +85,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
recurrent_weights->type);
TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
@ -364,8 +365,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
accum_scratch, row_sums, &op_data->compute_row_sums);
}
default:
context->ReportError(context, "Type %d not currently supported.",
input_weights->type);
TF_LITE_KERNEL_LOG(context, "Type %d not currently supported.",
TfLiteTypeGetName(input_weights->type));
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -68,7 +68,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < data->num; ++i) {
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
TfLiteTensor* output = GetOutput(context, node, i);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
// Guarantee input/output quantization params match as we do not support
// rescaling of unpacked quantized tensors.
TF_LITE_ENSURE_EQ(context, input->params.zero_point,

View File

@ -90,7 +90,7 @@ TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph,
TfLiteStatus CheckCondOutput(TfLiteContext* context,
const TfLiteTensor* cond_output) {
// The condition output must be a single boolean value.
TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool);
TF_LITE_ENSURE_TYPES_EQ(context, cond_output->type, kTfLiteBool);
if (cond_output->dims->size == 0) {
// It's okay if it's a 0D scalar.
return kTfLiteOk;
@ -179,7 +179,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
body_subgraph->tensor(body_subgraph->inputs()[i]);
TfLiteTensor* body_output =
body_subgraph->tensor(body_subgraph->outputs()[i]);
TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type);
TF_LITE_ENSURE_TYPES_EQ(context, body_input->type, body_output->type);
// TODO(ycling): Support dynamic sized body subgraph.
TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));

View File

@ -111,7 +111,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, data != nullptr);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");

View File

@ -32,8 +32,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
TF_LITE_ENSURE_EQ(context, output->bytes, input->bytes);
TF_LITE_ENSURE_EQ(context, output->dims->size, input->dims->size);
for (int i = 0; i < output->dims->size; ++i) {

View File

@ -89,10 +89,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]);
TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
// The circular buffer custom operator currently only supports int8.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt8);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
// TODO(b/132070898): Use statically slotted OpData structures until a
// scratch memory API is ready.

View File

@ -81,7 +81,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
#if defined(__ARM_FEATURE_DSP)

View File

@ -48,7 +48,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(

View File

@ -40,7 +40,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
TfLiteTypeGetName(input->type), input->type);
@ -54,7 +54,7 @@ inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, expected_type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);

View File

@ -29,7 +29,7 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
reference_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));

View File

@ -93,7 +93,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");

View File

@ -48,7 +48,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
output->type == kTfLiteUInt8 ||
output->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
@ -118,8 +118,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
depth, GetTensorData<int8>(input),
GetTensorData<int8>(output));
} else {
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float.",
output->type);
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@ -44,7 +44,7 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point,
std::numeric_limits<int8_t>::min());

View File

@ -48,7 +48,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(

View File

@ -61,7 +61,7 @@ TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
num_output_elements *= output_shape->data[stretch_dim];
}
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
return kTfLiteOk;
}

View File

@ -32,8 +32,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
TF_LITE_ENSURE_EQ(context, output->bytes, input->bytes);
TF_LITE_ENSURE_EQ(context, output->dims->size, input->dims->size);
for (int i = 0; i < output->dims->size; ++i) {

View File

@ -419,7 +419,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
const auto* input_params =
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
@ -467,7 +467,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

View File

@ -44,7 +44,7 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);

View File

@ -51,7 +51,7 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int err;
const float* inp_data_ptr;

View File

@ -432,7 +432,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// EvalIntegerSVDF().
// Validate output tensor:
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
} else {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
@ -457,7 +457,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
}
return kTfLiteOk;

View File

@ -348,7 +348,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
// Validate output tensor:
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
// Calculate effective scales.
auto* input_params =

View File

@ -205,6 +205,7 @@ void TfLiteFloatArrayFree(TfLiteFloatArray* a);
// the current function, while also reporting the location of the error.
// `a` and `b` may be evaluated more than once, so no side effects or
// extremely expensive computations should be done.
// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
#define TF_LITE_ENSURE_EQ(context, a, b) \
do { \
if ((a) != (b)) { \