[tflite]: Insert nullptr checks when obtaining tensors.
As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert `nullptr` checks on all usages. We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor`, but only where there is no obvious indication that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it were always non-null).

PiperOrigin-RevId: 332521299
Change-Id: I29af455bcb48d0b92e58132d951a3badbd772d56
commit 1970c2158b
parent fff2c83262
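The change applies one mechanical pattern throughout the diff below: each unchecked `GetInput`/`GetOutput` (and `GetTemporary`) call becomes a declaration followed by the corresponding `*Safe` accessor wrapped in `TF_LITE_ENSURE_OK`, so a failed tensor lookup returns an error status instead of dereferencing a null pointer later. A minimal before/after sketch of the pattern; the function names and tensor index 0 are illustrative, not taken from any particular file in this change, and the usual helpers from tensorflow/lite/kernels/kernel_util.h are assumed to be in scope:

// Before: GetInput/GetOutput may return nullptr, and the later
// dereference (input->type) would crash.
TfLiteStatus PrepareUnchecked(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  return kTfLiteOk;
}

// After: the Safe variants report a failed lookup through the context and
// TF_LITE_ENSURE_OK propagates the error before the tensors are touched.
TfLiteStatus PrepareChecked(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  return kTfLiteOk;
}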
@@ -252,8 +252,10 @@ void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

return context->ResizeTensor(context, output,
@@ -272,8 +274,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
@@ -300,12 +304,14 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {

TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
HardSwishParams* params = &data->params;
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
params->input_zero_point = input->params.zero_point;
params->output_zero_point = output->params.zero_point;
const float input_scale = input->params.scale;
@@ -337,8 +343,10 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
@@ -366,8 +374,10 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (kernel_type == kFixedPointOptimized) {
@@ -451,8 +461,10 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (kernel_type == kFixedPointOptimized) {
@@ -546,8 +558,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
input->type == kTfLiteUInt8 ||
@@ -614,8 +628,10 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
@@ -650,9 +666,12 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);

TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
@@ -704,8 +723,10 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@@ -732,8 +753,10 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@@ -763,8 +786,10 @@ template <KernelType kernel_type>
TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@@ -814,8 +839,10 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@@ -845,8 +872,10 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@@ -919,8 +948,10 @@ template <KernelType kernel_type>
TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
@@ -1067,8 +1098,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

switch (input->type) {
case kTfLiteFloat32: {
@@ -1122,8 +1155,10 @@ template <KernelType kernel_type>
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
const LogSoftmaxOpData* data =
reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
SoftmaxParams op_params;
@@ -1183,9 +1218,12 @@ T ApplyPrelu(T input, T alpha) {

template <KernelType kernel_type>
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
@@ -1294,8 +1332,10 @@ void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
}

TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const auto* params =
reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
const LeakyReluOpData* data =
@@ -1332,8 +1372,10 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);

// Use LUT to handle quantized elu path.
@@ -1346,8 +1388,10 @@ TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
@@ -91,9 +91,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@@ -358,9 +364,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);
@@ -33,13 +33,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, num_inputs >= 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input1->type;

// Check that all input tensors have the same shape and type.
for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
}
@@ -55,15 +60,22 @@ template <typename T>
void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
// TODO(haoliang): Initialize all_inputs only once during init.
VectorOfTensors<T> all_inputs(*context, *node->inputs);
// Safe to use unchecked since caller checks that tensor is valid
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int num_inputs = NumInputs(node);
// Safe to use unchecked since caller checks that tensor is valid
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
GetTensorData<T>(output));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32) {
EvalAddN<float>(context, node);
} else if (output->type == kTfLiteInt32) {
@@ -58,15 +58,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
// Make sure the axis is only 1 dimension.
TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
// Make sure the axis is only either int32 or int64.
TF_LITE_ENSURE(context,
axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
switch (params->output_type) {
@@ -119,9 +123,13 @@ std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
}
@@ -40,8 +40,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// everything still works fine when variable ops aren't used.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);

@@ -51,9 +52,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
const TfLiteTensor* input_value_tensor;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputValue, &input_value_tensor));

int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();
@@ -76,8 +76,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

@@ -106,8 +109,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
params->stride));
@ -60,13 +60,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* recurrent_weights =
|
||||
GetInput(context, node, kRecurrentWeightsTensor);
|
||||
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
|
||||
const TfLiteTensor* hidden_state =
|
||||
GetInput(context, node, kHiddenStateTensor);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
const TfLiteTensor* input_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
|
||||
const TfLiteTensor* recurrent_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
|
||||
const TfLiteTensor* bias;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
|
||||
const TfLiteTensor* hidden_state;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));
|
||||
|
||||
// Check all the parameters of tensor match within themselves and match the
|
||||
// input configuration.
|
||||
@ -86,7 +93,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
|
||||
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
|
||||
// Resize output.
|
||||
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
|
||||
@ -105,7 +114,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteIntArrayFree(node->temporaries);
|
||||
node->temporaries = TfLiteIntArrayCreate(6);
|
||||
node->temporaries->data[0] = op_data->scratch_tensor_index;
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||
TfLiteTensor* input_quantized;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
|
||||
&input_quantized));
|
||||
input_quantized->type = input_weights->type;
|
||||
input_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
|
||||
@ -114,8 +125,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
input_quantized_size));
|
||||
}
|
||||
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||
TfLiteTensor* hidden_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/1);
|
||||
TfLiteTensor* hidden_state_quantized;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
|
||||
&hidden_state_quantized));
|
||||
hidden_state_quantized->type = input_weights->type;
|
||||
hidden_state_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
|
||||
@ -127,7 +139,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
hidden_state_quantized_size));
|
||||
}
|
||||
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
|
||||
TfLiteTensor* scaling_factors;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
|
||||
&scaling_factors));
|
||||
scaling_factors->type = kTfLiteFloat32;
|
||||
scaling_factors->allocation_type = kTfLiteArenaRw;
|
||||
int scaling_dims[1] = {batch_size};
|
||||
@ -138,7 +152,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
scaling_factors_size));
|
||||
}
|
||||
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
|
||||
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
|
||||
TfLiteTensor* accum_scratch;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
|
||||
accum_scratch->type = kTfLiteInt32;
|
||||
accum_scratch->allocation_type = kTfLiteArenaRw;
|
||||
int accum_scratch_dims[2] = {num_units, batch_size};
|
||||
@ -151,7 +167,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
accum_scratch_size));
|
||||
}
|
||||
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
|
||||
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
|
||||
TfLiteTensor* zero_points;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
|
||||
zero_points->type = kTfLiteInt32;
|
||||
zero_points->allocation_type = kTfLiteArenaRw;
|
||||
int zero_points_dims[1] = {batch_size};
|
||||
@ -162,7 +180,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
zero_points_size));
|
||||
}
|
||||
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
|
||||
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
|
||||
TfLiteTensor* row_sums;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
|
||||
row_sums->type = kTfLiteInt32;
|
||||
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||
int row_sums_dims[2] = {2, num_units};
|
||||
@ -260,14 +280,23 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
|
||||
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* recurrent_weights =
|
||||
GetInput(context, node, kRecurrentWeightsTensor);
|
||||
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
const TfLiteTensor* input_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
|
||||
const TfLiteTensor* recurrent_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
|
||||
const TfLiteTensor* bias;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
|
||||
TfLiteTensor* hidden_state =
|
||||
&context->tensors[node->inputs->data[kHiddenStateTensor]];
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
GetVariableInput(context, node, kHiddenStateTensor);
|
||||
TF_LITE_ENSURE(context, hidden_state != nullptr);
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
|
||||
// We already checked that weight types are consistent, so branch on one.
|
||||
switch (input_weights->type) {
|
||||
@ -277,12 +306,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8: {
|
||||
// TODO(mirkov): implement eval with quantized inputs as well.
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
|
||||
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
|
||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
|
||||
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
|
||||
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
|
||||
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
|
||||
TfLiteTensor* input_quantized;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, 0, &input_quantized));
|
||||
TfLiteTensor* hidden_state_quantized;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
|
||||
TfLiteTensor* scaling_factors;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, 2, &scaling_factors));
|
||||
TfLiteTensor* accum_scratch;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, 3, &accum_scratch));
|
||||
TfLiteTensor* zero_points;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, 4, &zero_points));
|
||||
TfLiteTensor* row_sums;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
|
||||
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
|
||||
input_quantized, hidden_state_quantized,
|
||||
scaling_factors, hidden_state, output, zero_points,
|
||||
|
@ -154,7 +154,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
// Temp tensor for Transposed LHS;
|
||||
{
|
||||
node->temporaries->data[0] = op_data->scratch_tensor_index;
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||
TfLiteTensor* scratch_buffer;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
|
||||
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank);
|
||||
for (int i = 0; i < lhs_rank - 2; ++i) {
|
||||
scratch_buffer_size->data[i] = lhs->dims->data[i];
|
||||
@ -175,7 +177,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
// is set by the caller, the data is already in the desired layout.
|
||||
{
|
||||
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/1);
|
||||
TfLiteTensor* scratch_buffer;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer));
|
||||
const TfLiteTensor* rhs = op_context->rhs;
|
||||
int rhs_rank = NumDimensions(rhs);
|
||||
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank);
|
||||
@ -215,7 +219,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
op_data->compute_row_sums = true;
|
||||
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
|
||||
TfLiteTensor* input_quantized;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
|
||||
&input_quantized));
|
||||
input_quantized->type = op_context->rhs->type;
|
||||
input_quantized->allocation_type = kTfLiteArenaRw;
|
||||
|
||||
@ -225,7 +231,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
input_quantized_size));
|
||||
|
||||
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
|
||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
|
||||
TfLiteTensor* scaling_factors;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
|
||||
&scaling_factors));
|
||||
scaling_factors->type = kTfLiteFloat32;
|
||||
scaling_factors->allocation_type = kTfLiteArenaRw;
|
||||
// Total size of scaling factors is batch size * number of total batches
|
||||
@ -238,7 +246,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
|
||||
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
|
||||
TfLiteTensor* accum_scratch;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
|
||||
accum_scratch->type = kTfLiteInt32;
|
||||
accum_scratch->allocation_type = kTfLiteArenaRw;
|
||||
int accum_scratch_dims[2] = {num_units, batch_size};
|
||||
@ -252,7 +262,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
|
||||
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
|
||||
TfLiteTensor* input_offsets;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
|
||||
input_offsets->type = kTfLiteInt32;
|
||||
input_offsets->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
|
||||
@ -262,7 +274,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||
input_offsets_size));
|
||||
}
|
||||
node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
|
||||
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
|
||||
TfLiteTensor* row_sums;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, /*index=*/6, &row_sums));
|
||||
row_sums->type = kTfLiteInt32;
|
||||
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||
int row_sums_dims[1] = {num_weights_matrices * num_units};
|
||||
@ -288,9 +302,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
bool adj_x = op_context.params->adj_x;
|
||||
bool adj_y = op_context.params->adj_y;
|
||||
|
||||
const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
|
||||
const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* lhs_data;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputLHSTensor, &lhs_data));
|
||||
const TfLiteTensor* rhs_data;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputRHSTensor, &rhs_data));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
|
||||
// Note that quantized inference requires that all tensors have their
|
||||
// parameters set. This is usually done during quantized training.
|
||||
@ -502,11 +522,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const RuntimeShape& rhs_shape,
|
||||
const TfLiteTensor* rhs, TfLiteTensor* output) {
|
||||
if (lhs->type == kTfLiteFloat32) {
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
|
||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
|
||||
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
|
||||
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
|
||||
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
|
||||
TfLiteTensor* input_quantized;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
|
||||
&input_quantized));
|
||||
TfLiteTensor* scaling_factors;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
|
||||
&scaling_factors));
|
||||
TfLiteTensor* accum_scratch;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
|
||||
TfLiteTensor* input_offsets;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
|
||||
TfLiteTensor* row_sums;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, /*index=*/6, &row_sums));
|
||||
return EvalHybrid<kernel_type>(
|
||||
context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
|
||||
scaling_factors, accum_scratch, row_sums, input_offsets, output);
|
||||
@ -524,6 +554,10 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* rhs) {
|
||||
TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
|
||||
if (transposed_rhs == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (rhs->type == kTfLiteInt8) {
|
||||
// Get the quantization params from the RHS tensor.
|
||||
transposed_rhs->params.scale = rhs->params.scale;
|
||||
@ -535,6 +569,10 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* lhs) {
|
||||
TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
|
||||
if (transposed_lhs == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (lhs->type == kTfLiteInt8) {
|
||||
// Get the quantization params from the LHS tensor.
|
||||
transposed_lhs->params.scale = lhs->params.scale;
|
||||
@ -558,9 +596,15 @@ template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpContext op_context(context, node);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
|
||||
const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* lhs;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputLHSTensor, &lhs));
|
||||
const TfLiteTensor* rhs;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputRHSTensor, &rhs));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
RuntimeShape orig_lhs_shape = GetTensorShape(lhs);
|
||||
RuntimeShape orig_rhs_shape = GetTensorShape(rhs);
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
@ -192,8 +193,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
|
||||
TF_LITE_ENSURE(context, params->cell_clip >= 0);
|
||||
TF_LITE_ENSURE(context, params->proj_clip >= 0);
|
||||
|
||||
const TfLiteTensor* input_to_forget_weights =
|
||||
GetInput(context, node, input_to_forget_weights_tensor);
|
||||
const TfLiteTensor* input_to_forget_weights;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, input_to_forget_weights_tensor,
|
||||
&input_to_forget_weights));
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
|
||||
@ -211,16 +214,20 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
|
||||
input_to_forget_weights->type);
|
||||
}
|
||||
|
||||
const TfLiteTensor* input_to_cell_weights =
|
||||
GetInput(context, node, input_to_cell_weights_tensor);
|
||||
const TfLiteTensor* input_to_cell_weights;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, input_to_cell_weights_tensor,
|
||||
&input_to_cell_weights));
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
|
||||
input_to_forget_weights->type);
|
||||
|
||||
const TfLiteTensor* input_to_output_weights =
|
||||
GetInput(context, node, input_to_output_weights_tensor);
|
||||
const TfLiteTensor* input_to_output_weights;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, input_to_output_weights_tensor,
|
||||
&input_to_output_weights));
|
||||
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
|
||||
@ -239,8 +246,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
|
||||
input_to_forget_weights->type);
|
||||
}
|
||||
|
||||
const TfLiteTensor* recurrent_to_forget_weights =
|
||||
GetInput(context, node, recurrent_to_forget_weights_tensor);
|
||||
const TfLiteTensor* recurrent_to_forget_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
|
||||
&recurrent_to_forget_weights));
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
|
||||
n_cell);
|
||||
@ -249,8 +258,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
|
||||
input_to_forget_weights->type);
|
||||
|
||||
const TfLiteTensor* recurrent_to_cell_weights =
|
||||
GetInput(context, node, recurrent_to_cell_weights_tensor);
|
||||
const TfLiteTensor* recurrent_to_cell_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
|
||||
&recurrent_to_cell_weights));
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
|
||||
@ -316,20 +327,25 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
|
||||
}
|
||||
|
||||
const TfLiteTensor* forget_gate_bias =
|
||||
GetInput(context, node, forget_gate_bias_tensor);
|
||||
const TfLiteTensor* forget_gate_bias;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
|
||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
|
||||
|
||||
const TfLiteTensor* cell_gate_bias =
|
||||
GetInput(context, node, cell_gate_bias_tensor);
|
||||
const TfLiteTensor* cell_gate_bias;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
|
||||
&cell_gate_bias));
|
||||
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
|
||||
|
||||
const TfLiteTensor* output_gate_bias =
|
||||
GetInput(context, node, output_gate_bias_tensor);
|
||||
const TfLiteTensor* output_gate_bias;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
|
||||
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
|
||||
@ -413,7 +429,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// Inferring batch size, number of outputs and sequence length and
|
||||
// number of cells from the input tensors.
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
|
||||
const bool time_major = params->time_major;
|
||||
@ -421,15 +438,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
|
||||
const int n_input = input->dims->data[2];
|
||||
|
||||
const TfLiteTensor* fw_input_to_output_weights =
|
||||
GetInput(context, node, kFwInputToOutputWeightsTensor);
|
||||
const TfLiteTensor* fw_input_to_output_weights;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
|
||||
&fw_input_to_output_weights));
|
||||
const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
|
||||
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
|
||||
n_input);
|
||||
|
||||
const TfLiteTensor* bw_input_to_output_weights =
|
||||
GetInput(context, node, kBwInputToOutputWeightsTensor);
|
||||
const TfLiteTensor* bw_input_to_output_weights;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
|
||||
&bw_input_to_output_weights));
|
||||
const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
|
||||
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
|
||||
@ -437,8 +458,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
|
||||
fw_input_to_output_weights->type);
|
||||
|
||||
const TfLiteTensor* fw_recurrent_to_output_weights =
|
||||
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
|
||||
const TfLiteTensor* fw_recurrent_to_output_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
|
||||
&fw_recurrent_to_output_weights));
|
||||
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
|
||||
n_fw_cell);
|
||||
@ -446,8 +469,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
fw_input_to_output_weights->type);
|
||||
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
const TfLiteTensor* bw_recurrent_to_output_weights =
|
||||
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
|
||||
const TfLiteTensor* bw_recurrent_to_output_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
|
||||
&bw_recurrent_to_output_weights));
|
||||
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
|
||||
n_bw_cell);
|
||||
@ -504,7 +529,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
// Get the pointer to output, activation_state and cell_state buffer tensors.
|
||||
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
|
||||
TfLiteTensor* fw_output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
|
||||
TfLiteTensor* fw_activation_state =
|
||||
GetVariableInput(context, node, kFwInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, fw_activation_state != nullptr);
|
||||
@ -541,8 +568,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Create a scratch buffer tensor.
|
||||
node->temporaries->data[kFwScratchBuffer] =
|
||||
op_data->scratch_tensor_index + kFwScratchBuffer;
|
||||
TfLiteTensor* fw_scratch_buffer =
|
||||
GetTemporary(context, node, kFwScratchBuffer);
|
||||
TfLiteTensor* fw_scratch_buffer;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
|
||||
&fw_scratch_buffer));
|
||||
fw_scratch_buffer->type = input->type;
|
||||
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||
|
||||
@ -581,7 +609,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// Resize the output tensors.
|
||||
if (!params->merge_outputs) {
|
||||
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
|
||||
TfLiteTensor* bw_output;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
|
||||
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
|
||||
bw_output_size->data[0] = time_major ? max_time : n_batch;
|
||||
bw_output_size->data[1] = time_major ? n_batch : max_time;
|
||||
@ -600,8 +630,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Create a scratch buffer tensor.
|
||||
node->temporaries->data[kBwScratchBuffer] =
|
||||
op_data->scratch_tensor_index + kBwScratchBuffer;
|
||||
TfLiteTensor* bw_scratch_buffer =
|
||||
GetTemporary(context, node, kBwScratchBuffer);
|
||||
TfLiteTensor* bw_scratch_buffer;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
|
||||
&bw_scratch_buffer));
|
||||
bw_scratch_buffer->type = input->type;
|
||||
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||
|
||||
@ -631,8 +662,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// (if present), activation_state and cell_state tensors.
|
||||
node->temporaries->data[kInputQuantized] =
|
||||
op_data->scratch_tensor_index + kInputQuantized;
|
||||
TfLiteTensor* input_quantized =
|
||||
GetTemporary(context, node, kInputQuantized);
|
||||
TfLiteTensor* input_quantized;
|
||||
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
|
||||
&input_quantized));
|
||||
input_quantized->type = fw_input_to_output_weights->type;
|
||||
input_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
|
||||
@ -643,8 +675,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

node->temporaries->data[kFwActivationStateQuantized] =
op_data->scratch_tensor_index + kFwActivationStateQuantized;
TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized);
TfLiteTensor* fw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
&fw_activation_state_quantized));
fw_activation_state_quantized->type = fw_input_to_output_weights->type;
fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
@ -657,8 +691,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwActivationStateQuantized] =
op_data->scratch_tensor_index + kBwActivationStateQuantized;
TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized);
TfLiteTensor* bw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
&bw_activation_state_quantized));
bw_activation_state_quantized->type = fw_input_to_output_weights->type;
bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
@ -671,8 +707,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwCellStateQuantized] =
op_data->scratch_tensor_index + kFwCellStateQuantized;
TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized);
TfLiteTensor* fw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwCellStateQuantized,
&fw_cell_state_quantized));
fw_cell_state_quantized->type = fw_input_to_output_weights->type;
fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
@ -685,8 +723,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwCellStateQuantized] =
op_data->scratch_tensor_index + kBwCellStateQuantized;
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
TfLiteTensor* bw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwCellStateQuantized,
&bw_cell_state_quantized));
bw_cell_state_quantized->type = fw_input_to_output_weights->type;
bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
@ -705,7 +745,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -717,8 +760,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAuxInputScalingFactors] =
op_data->scratch_tensor_index + kAuxInputScalingFactors;
TfLiteTensor* aux_input_sf =
GetTemporary(context, node, kAuxInputScalingFactors);
TfLiteTensor* aux_input_sf;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputScalingFactors,
&aux_input_sf));
aux_input_sf->type = kTfLiteFloat32;
aux_input_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
@ -729,8 +774,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -741,8 +788,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -758,8 +807,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_fw_cell};
@ -775,8 +826,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] =
op_data->scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int n_cell = std::max(n_fw_cell, n_bw_cell);
@ -797,7 +850,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors for storing zero-points.
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -808,8 +863,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAuxInputZeroPoints] =
op_data->scratch_tensor_index + kAuxInputZeroPoints;
TfLiteTensor* aux_input_zp =
GetTemporary(context, node, kAuxInputZeroPoints);
TfLiteTensor* aux_input_zp;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
aux_input_zp->type = kTfLiteFloat32;
aux_input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
@ -820,8 +877,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -844,7 +903,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
@ -867,7 +928,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
@ -884,8 +947,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
TfLiteTensor* aux_input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputQuantized,
&aux_input_quantized));
aux_input_quantized->type = fw_input_to_output_weights->type;
aux_input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@ -906,26 +971,39 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
const TfLiteTensor* fw_input_to_forget_weights =
GetInput(context, node, kFwInputToForgetWeightsTensor);
const TfLiteTensor* fw_input_to_cell_weights =
GetInput(context, node, kFwInputToCellWeightsTensor);
const TfLiteTensor* fw_input_to_output_weights =
GetInput(context, node, kFwInputToOutputWeightsTensor);
const TfLiteTensor* fw_input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
&fw_input_to_forget_weights));
const TfLiteTensor* fw_input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToCellWeightsTensor,
&fw_input_to_cell_weights));
const TfLiteTensor* fw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
&fw_input_to_output_weights));

const TfLiteTensor* fw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
const TfLiteTensor* fw_recurrent_to_forget_weights =
GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
const TfLiteTensor* fw_recurrent_to_cell_weights =
GetInput(context, node, kFwRecurrentToCellWeightsTensor);
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
const TfLiteTensor* fw_recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
&fw_recurrent_to_forget_weights));
const TfLiteTensor* fw_recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
&fw_recurrent_to_cell_weights));
const TfLiteTensor* fw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
&fw_recurrent_to_output_weights));

const TfLiteTensor* fw_cell_to_input_weights =
GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
@ -936,12 +1014,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

const TfLiteTensor* fw_input_gate_bias =
GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
const TfLiteTensor* fw_forget_gate_bias =
GetInput(context, node, kFwForgetGateBiasTensor);
const TfLiteTensor* fw_cell_gate_bias =
GetInput(context, node, kFwCellGateBiasTensor);
const TfLiteTensor* fw_output_gate_bias =
GetInput(context, node, kFwOutputGateBiasTensor);
const TfLiteTensor* fw_forget_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwForgetGateBiasTensor,
&fw_forget_gate_bias));
const TfLiteTensor* fw_cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
&fw_cell_gate_bias));
const TfLiteTensor* fw_output_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwOutputGateBiasTensor,
&fw_output_gate_bias));

const TfLiteTensor* fw_projection_weights =
GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
@ -950,30 +1033,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* fw_activation_state =
GetVariableInput(context, node, kFwInputActivationStateTensor);
TF_LITE_ENSURE(context, fw_activation_state != nullptr);
TFLITE_DCHECK(fw_activation_state != nullptr);
TfLiteTensor* fw_cell_state =
GetVariableInput(context, node, kFwInputCellStateTensor);
TF_LITE_ENSURE(context, fw_cell_state != nullptr);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TFLITE_DCHECK(fw_cell_state != nullptr);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));

// Tensors for the backward cell.
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
const TfLiteTensor* bw_input_to_forget_weights =
GetInput(context, node, kBwInputToForgetWeightsTensor);
const TfLiteTensor* bw_input_to_cell_weights =
GetInput(context, node, kBwInputToCellWeightsTensor);
const TfLiteTensor* bw_input_to_output_weights =
GetInput(context, node, kBwInputToOutputWeightsTensor);
const TfLiteTensor* bw_input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
&bw_input_to_forget_weights));
const TfLiteTensor* bw_input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToCellWeightsTensor,
&bw_input_to_cell_weights));
const TfLiteTensor* bw_input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
&bw_input_to_output_weights));

const TfLiteTensor* bw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
const TfLiteTensor* bw_recurrent_to_forget_weights =
GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
const TfLiteTensor* bw_recurrent_to_cell_weights =
GetInput(context, node, kBwRecurrentToCellWeightsTensor);
const TfLiteTensor* bw_recurrent_to_output_weights =
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
const TfLiteTensor* bw_recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
&bw_recurrent_to_forget_weights));
const TfLiteTensor* bw_recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
&bw_recurrent_to_cell_weights));
const TfLiteTensor* bw_recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
&bw_recurrent_to_output_weights));

const TfLiteTensor* bw_cell_to_input_weights =
GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
@ -984,12 +1081,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

const TfLiteTensor* bw_input_gate_bias =
GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
const TfLiteTensor* bw_forget_gate_bias =
GetInput(context, node, kBwForgetGateBiasTensor);
const TfLiteTensor* bw_cell_gate_bias =
GetInput(context, node, kBwCellGateBiasTensor);
const TfLiteTensor* bw_output_gate_bias =
GetInput(context, node, kBwOutputGateBiasTensor);
const TfLiteTensor* bw_forget_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwForgetGateBiasTensor,
&bw_forget_gate_bias));
const TfLiteTensor* bw_cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
&bw_cell_gate_bias));
const TfLiteTensor* bw_output_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwOutputGateBiasTensor,
&bw_output_gate_bias));

const TfLiteTensor* bw_projection_weights =
GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
@ -999,19 +1101,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// State tensors.
TfLiteTensor* bw_activation_state =
GetVariableInput(context, node, kBwInputActivationStateTensor);
TF_LITE_ENSURE(context, bw_activation_state != nullptr);
TFLITE_DCHECK(bw_activation_state != nullptr);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
TF_LITE_ENSURE(context, bw_cell_state != nullptr);
TFLITE_DCHECK(bw_cell_state != nullptr);
TfLiteTensor* bw_output = params->merge_outputs
? nullptr
: GetOutput(context, node, kBwOutputTensor);

// Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
GetTemporary(context, node, kBwScratchBuffer);
TfLiteTensor* fw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
&fw_scratch_buffer));
TfLiteTensor* bw_scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
&bw_scratch_buffer));

// (Optional) auxiliary inputs.
const TfLiteTensor* aux_input =
@ -1112,27 +1216,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteUInt8:
case kTfLiteInt8: {
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized);
TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized);
TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized);
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
TfLiteTensor* fw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
&fw_activation_state_quantized));
TfLiteTensor* bw_activation_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
&bw_activation_state_quantized));
TfLiteTensor* fw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwCellStateQuantized,
&fw_cell_state_quantized));
TfLiteTensor* bw_cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwCellStateQuantized,
&bw_cell_state_quantized));
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
const int fw_row_sums_size = fw_row_sums->dims->data[0];
const int bw_row_sums_size = bw_row_sums->dims->data[0];
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(

@ -97,21 +97,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size,
params->merge_outputs ? 1 : 2);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
GetInput(context, node, kFwWeightsTensor);
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
const TfLiteTensor* fw_hidden_state =
GetInput(context, node, kFwHiddenStateTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
const TfLiteTensor* bw_hidden_state =
GetInput(context, node, kBwHiddenStateTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
&fw_input_weights));
const TfLiteTensor* fw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentWeightsTensor,
&fw_recurrent_weights));
const TfLiteTensor* fw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
const TfLiteTensor* fw_hidden_state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
&fw_hidden_state));
const TfLiteTensor* bw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
&bw_input_weights));
const TfLiteTensor* bw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentWeightsTensor,
&bw_recurrent_weights));
const TfLiteTensor* bw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
const TfLiteTensor* bw_hidden_state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
&bw_hidden_state));

const TfLiteTensor* aux_input =
GetOptionalInputTensor(context, node, kAuxInputTensor);
@ -186,8 +199,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = fw_input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -198,8 +212,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

node->temporaries->data[kFwHiddenStateQuantized] =
op_data->scratch_tensor_index + kFwHiddenStateQuantized;
TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized);
TfLiteTensor* fw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwHiddenStateQuantized,
&fw_hidden_state_quantized));
fw_hidden_state_quantized->type = fw_input_weights->type;
fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
@ -213,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

node->temporaries->data[kBwHiddenStateQuantized] =
op_data->scratch_tensor_index + kBwHiddenStateQuantized;
TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* bw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwHiddenStateQuantized,
&bw_hidden_state_quantized));
bw_hidden_state_quantized->type = fw_input_weights->type;
bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
@ -229,8 +247,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store scaling factors of quantization.
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -242,7 +261,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
@ -257,8 +278,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -271,8 +294,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int num_row_sums = has_aux_input ? 3 : 2;
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums =
GetTemporary(context, node, /*index=*/kFwRowSums);
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
@ -285,8 +310,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node,
/*index=*/kBwRowSums);
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
@ -300,8 +327,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
TfLiteTensor* aux_input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kAuxInputQuantized,
&aux_input_quantized));
aux_input_quantized->type = fw_input_weights->type;
aux_input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@ -315,7 +344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

// Resize outputs.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
@ -324,7 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
if (!params->merge_outputs) {
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
bw_output_size_array->data[0] = batch_size;
bw_output_size_array->data[1] = max_time;
@ -678,17 +711,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
GetInput(context, node, kFwWeightsTensor);
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
&fw_input_weights));
const TfLiteTensor* fw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwRecurrentWeightsTensor,
&fw_recurrent_weights));
const TfLiteTensor* fw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
const TfLiteTensor* bw_input_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
&bw_input_weights));
const TfLiteTensor* bw_recurrent_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwRecurrentWeightsTensor,
&bw_recurrent_weights));
const TfLiteTensor* bw_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

// Get auxiliary inputs.
const TfLiteTensor* aux_input =
@ -700,12 +744,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* fw_hidden_state =
GetVariableInput(context, node, kFwHiddenStateTensor);
TF_LITE_ENSURE(context, fw_hidden_state != nullptr);
TFLITE_DCHECK(fw_hidden_state != nullptr);
TfLiteTensor* bw_hidden_state =
GetVariableInput(context, node, kBwHiddenStateTensor);
TF_LITE_ENSURE(context, bw_hidden_state != nullptr);
TFLITE_DCHECK(bw_hidden_state != nullptr);

TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
TfLiteTensor* bw_output = params->merge_outputs
? nullptr
: GetOutput(context, node, kBwOutputTensor);
@ -741,18 +787,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_hidden_state, bw_output);
case kTfLiteUInt8:
case kTfLiteInt8: {
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized);
TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
TfLiteTensor* fw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFwHiddenStateQuantized,
&fw_hidden_state_quantized));
TfLiteTensor* bw_hidden_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kBwHiddenStateQuantized,
&bw_hidden_state_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
TfLiteTensor* fw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
TfLiteTensor* bw_row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;

@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// TODO(ahentz): these two checks would make the new implementation
// incompatible with some existing models, where params is not specified. It
@ -98,8 +101,11 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const int num_elements = NumElements(input);
TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
switch (input->type) {

@ -29,8 +29,11 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -40,8 +43,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (input->type != kTfLiteFloat32) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Ceil");
}

@ -41,9 +41,15 @@ TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Don't support string.
if (!is_string_allowed) {
@ -145,9 +151,15 @@ void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
}

TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteBool:
@ -189,9 +201,15 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteBool:
@ -233,9 +251,15 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -268,9 +292,15 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -303,9 +333,15 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
@ -338,9 +374,15 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:

@ -45,7 +45,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// The number of dimensions of the input tensors must match, and all
// dimensions except 'axis' must be equal.
const TfLiteTensor* t0 = GetInput(context, node, 0);
const TfLiteTensor* t0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
TfLiteType input_type = t0->type;
if (axis < 0) axis += t0->dims->size;
TF_LITE_ENSURE(context, axis >= 0);
@ -63,7 +64,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// will be the sum of inputs
int sum_axis = t0->dims->data[axis];
for (int i = 1; i < num_inputs; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
const TfLiteTensor* t;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
TF_LITE_ENSURE_EQ(context, t->type, input_type);
for (int d = 0; d < t0->dims->size; ++d) {
@ -80,7 +82,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
}

TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);

if (input_type == kTfLiteInt8) {
@ -88,7 +91,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// is a restriction we introduced to Int8 kernels.
VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
const TfLiteTensor* t;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
TF_LITE_ENSURE_EQ(context, t->params.zero_point,
output->params.zero_point);
@ -103,7 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
int axis = params->axis;
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (axis < 0) axis += output->dims->size;

// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should

@ -222,8 +222,10 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TF_LITE_ENSURE(context, node->inputs->size >= 2);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
@ -316,9 +318,12 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

// Check dimensionality of input, filter
TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
@ -340,7 +345,7 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
TF_LITE_ENSURE(context, has_bias);

if (has_bias) {
bias = GetInput(context, node, 2);
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@ -493,8 +498,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
if (is_hybrid) {
node->temporaries->data[data->input_quantized_index] =
data->input_quantized_id;
TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
input_quantized->type = kTfLiteInt8;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -505,8 +512,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,

node->temporaries->data[data->scaling_factors_index] =
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
// Only one scale factor per batch is typically necessary. See optimized
@ -522,8 +531,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
}

node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, data->accum_scratch_index);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->accum_scratch_index,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
const int scratch_width = batches * out_height * out_width;
@ -545,8 +556,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
context, affine_quantization->scale->size,
filter->dims->data[affine_quantization->quantized_dimension]);
node->temporaries->data[data->input_offset_index] = data->input_offset_id;
TfLiteTensor* input_offsets =
GetTemporary(context, node, data->input_offset_index);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_offset_index,
&input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
// See above comment for the need to allocate for height of inputs.
@ -560,8 +573,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
input_offsets_size));
}
node->temporaries->data[data->row_sums_index] = data->row_sums_id;
TfLiteTensor* row_sums =
GetTemporary(context, node, data->row_sums_index);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
// See above comment for the need to allocate for height of inputs.
@ -802,23 +817,34 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}

template <KernelType kernel_type>
void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* output) {
TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);

const int input_size = NumElements(input) / SizeOfDimension(input, 0);
const int batch_size = SizeOfDimension(input, 0);
int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
GetTemporary(context, node, data->input_quantized_index));
float* scaling_factors_ptr = GetTensorData<float>(
GetTemporary(context, node, data->scaling_factors_index));
int32_t* input_offset_ptr = GetTensorData<int32_t>(
GetTemporary(context, node, data->input_offset_index));
TfLiteTensor* quantized_input_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_quantized_index,
&quantized_input_tensor));
int8_t* quantized_input_ptr_batch =
GetTensorData<int8_t>(quantized_input_tensor);
TfLiteTensor* scaling_factors_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors_tensor));
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
TfLiteTensor* input_offset_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offset_tensor));
int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);

for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
@ -859,10 +885,14 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
case kGenericOptimized:
case kMultithreadOptimized:
case kCblasOptimized: {
TfLiteTensor* row_sums =
GetTemporary(context, node, data->row_sums_index);
TfLiteTensor* scratch =
GetTemporary(context, node, data->accum_scratch_index);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
TfLiteTensor* scratch;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, data->accum_scratch_index, &scratch));
optimized_ops::HybridConvPerChannel(
op_params, scaling_factors_ptr, GetTensorShape(input),
quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
@ -877,14 +907,16 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
break;
}
}

return kTfLiteOk;
}

template <KernelType kernel_type>
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
@ -893,10 +925,17 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
||||
const int batch_size = SizeOfDimension(input, 0);
|
||||
|
||||
const float* input_ptr = GetTensorData<float>(input);
|
||||
int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
|
||||
GetTemporary(context, node, data->input_quantized_index));
|
||||
float* scaling_factors_ptr = GetTensorData<float>(
|
||||
GetTemporary(context, node, data->scaling_factors_index));
|
||||
TfLiteTensor* quantized_input_tensor;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, data->input_quantized_index,
|
||||
&quantized_input_tensor));
|
||||
int8_t* quantized_input_ptr_batch =
|
||||
GetTensorData<int8_t>(quantized_input_tensor);
|
||||
TfLiteTensor* scaling_factors_tensor;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetTemporarySafe(context, node, data->scaling_factors_index,
|
||||
&scaling_factors_tensor));
|
||||
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
|
||||
|
||||
// Per-batch input quantization for higher accuracy.
|
||||
{
|
||||
@ -939,6 +978,8 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
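(Aside: the switch from void to TfLiteStatus in EvalHybridPerChannel and EvalHybrid above is what lets the new checks propagate, since TF_LITE_ENSURE_OK returns early from the enclosing function on failure and therefore can only appear in a status-returning function. As a rough sketch only, and glossing over any error reporting the real macro in tensorflow/lite/c/common.h may also do, it behaves like:

#define TF_LITE_ENSURE_OK(context, status) \
  do {                                     \
    const TfLiteStatus s = (status);       \
    if (s != kTfLiteOk) {                  \
      return s;                            \
    }                                      \
  } while (0)
)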
template <KernelType kernel_type, TfLiteType input_type>
@ -946,9 +987,12 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* filter = GetInput(context, node, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
bool has_bias = node->inputs->size == 3;
const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
TfLiteTensor* im2col =
@ -970,14 +1014,17 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
if (data->is_hybrid_per_channel) {
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
filter, bias, im2col, output);
TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
context, node, params, data, input,
filter, bias, im2col, output));
} else {
TfLiteTensor* accum_scratch =
&context->tensors[node->temporaries
->data[data->accum_scratch_index]];
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
bias, im2col, accum_scratch, output);
TF_LITE_ENSURE_OK(context,
EvalHybrid<kernel_type>(context, node, params, data,
input, filter, bias, im2col,
accum_scratch, output));
}
} else {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
@ -1006,7 +1053,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));

switch (input->type) {
case kTfLiteFloat32:
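(Aside: every call site in this file now follows the same out-parameter pattern: declare the tensor pointer, fill it through a *Safe accessor, and bail out if the accessor fails. A hedged sketch of what such an accessor amounts to, assuming the helpers declared in tensorflow/lite/kernels/kernel_util.h; the real implementation may validate more than shown here:

TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
                          int index, const TfLiteTensor** tensor) {
  // Sketch only: refuse out-of-range or omitted inputs instead of handing
  // back a null pointer for the caller to dereference.
  if (index < 0 || index >= node->inputs->size) return kTfLiteError;
  const int tensor_index = node->inputs->data[index];
  if (tensor_index == kTfLiteOptionalTensor) return kTfLiteError;
  *tensor = &context->tensors[tensor_index];
  return kTfLiteOk;
}
)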
@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);

@ -84,8 +87,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

#define TF_LITE_DEPTH_TO_SPACE(type, scalar) \
tflite::DepthToSpaceParams op_params; \

@ -104,12 +104,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bool hasBias = NumInputs(node) == 3;

TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFilterTensor, &filter));
const TfLiteTensor* bias = nullptr;

TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4);
@ -132,7 +137,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 0), 1);

if (hasBias) {
bias = GetInput(context, node, kBiasTensor);
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) {
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@ -224,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

node->temporaries->data[data->input_quantized_index] =
data->input_quantized_id;
TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
input_quantized->type = kTfLiteInt8;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -235,8 +242,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[data->scaling_factors_index] =
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
const int batch_size = SizeOfDimension(input, 0);
@ -248,8 +257,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[data->input_offset_index] = data->input_offset_id;
TfLiteTensor* input_offsets =
GetTemporary(context, node, data->input_offset_index);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@ -446,13 +457,21 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
&output_activation_max);
const int input_size = NumElements(input) / SizeOfDimension(input, 0);
const int batch_size = SizeOfDimension(input, 0);
const TfLiteTensor* input_quantized =
GetTemporary(context, node, data->input_quantized_index);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_quantized_index,
&input_quantized));
int8_t* quantized_input_ptr_batch = input_quantized->data.int8;
float* scaling_factors_ptr = GetTensorData<float>(
GetTemporary(context, node, data->scaling_factors_index));
int32_t* input_offset_ptr = GetTensorData<int32_t>(
GetTemporary(context, node, data->input_offset_index));
TfLiteTensor* scaling_factors_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->scaling_factors_index,
&scaling_factors_tensor));
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
TfLiteTensor* input_offset_tensor;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, data->input_offset_index,
&input_offset_tensor));
int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);

for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
@ -504,9 +523,14 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFilterTensor, &filter));
const TfLiteTensor* bias =
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
TFLITE_DCHECK_EQ(input_type, input->type);
@ -547,7 +571,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:

@ -146,12 +146,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = static_cast<OpData*>(node->user_data);
// Inputs: box_encodings, scores, anchors
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_anchors =
GetInput(context, node, kInputTensorAnchors);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* input_anchors;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
&input_anchors));
TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
@ -163,27 +168,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// num_detections
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
// Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
detection_boxes->type = kTfLiteFloat32;
SetTensorSizes(context, detection_boxes,
{kBatchSize, num_detected_boxes, kNumCoordBox});

// Output Tensor detection_classes: size is set to (1, num_detected_boxes)
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
detection_classes->type = kTfLiteFloat32;
SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});

// Output Tensor detection_scores: size is set to (1, num_detected_boxes)
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
detection_scores->type = kTfLiteFloat32;
SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});

// Output Tensor num_detections: size is set to 1
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));
num_detections->type = kTfLiteFloat32;
// TODO (chowdhery): Make it a scalar when available
SetTensorSizes(context, num_detections, {1});
@ -269,13 +282,16 @@ T ReInterpretTensor(TfLiteTensor* tensor) {
TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
OpData* op_data) {
// Parse input tensor boxencodings
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
const int num_boxes = input_box_encodings->dims->data[1];
TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
const TfLiteTensor* input_anchors =
GetInput(context, node, kInputTensorAnchors);
const TfLiteTensor* input_anchors;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
&input_anchors));

// Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
CenterSizeEncoding box_centersize;
@ -389,8 +405,10 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
TfLiteContext* context, TfLiteNode* node, OpData* op_data,
const std::vector<float>& scores, std::vector<int>* selected,
int max_detections) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];
const int num_boxes = input_box_encodings->dims->data[1];
@ -468,21 +486,33 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
TfLiteNode* node,
OpData* op_data,
const float* scores) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];

TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));

const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
@ -595,21 +625,33 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
TfLiteNode* node,
OpData* op_data,
const float* scores) {
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const TfLiteTensor* decoded_boxes =
&context->tensors[op_data->decoded_boxes_index];

TfLiteTensor* detection_boxes =
GetOutput(context, node, kOutputTensorDetectionBoxes);
TfLiteTensor* detection_classes =
GetOutput(context, node, kOutputTensorDetectionClasses);
TfLiteTensor* detection_scores =
GetOutput(context, node, kOutputTensorDetectionScores);
TfLiteTensor* num_detections =
GetOutput(context, node, kOutputTensorNumDetections);
TfLiteTensor* detection_boxes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
&detection_boxes));
TfLiteTensor* detection_classes;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionClasses,
&detection_classes));
TfLiteTensor* detection_scores;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorDetectionScores,
&detection_scores));
TfLiteTensor* num_detections;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensorNumDetections,
&num_detections));

const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
@ -680,10 +722,14 @@ void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
TfLiteNode* node, OpData* op_data) {
// Get the input tensors
const TfLiteTensor* input_box_encodings =
GetInput(context, node, kInputTensorBoxEncodings);
const TfLiteTensor* input_class_predictions =
GetInput(context, node, kInputTensorClassPredictions);
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
&input_box_encodings));
const TfLiteTensor* input_class_predictions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorClassPredictions,
&input_class_predictions));
const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],

@ -74,9 +74,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -200,9 +206,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);

@ -66,8 +66,10 @@ template <IsSupportedType is_supported_type, const char* op_name>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!is_supported_type(input->type)) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
@ -114,8 +116,10 @@ template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
std::function<T(T)> func,
TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);

@ -46,14 +46,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);

const TfLiteTensor* value = GetInput(context, node, 1);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));

outputSize->data[0] = SizeOfDimension(lookup, 0);
@ -129,9 +132,12 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* value = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (value->type) {
case kTfLiteFloat32:
return EvalSimple(context, node, lookup, value, output);

@ -83,19 +83,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* ids = GetInput(context, node, 0);
const TfLiteTensor* ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);

const TfLiteTensor* indices = GetInput(context, node, 1);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);

const TfLiteTensor* shape = GetInput(context, node, 2);
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);

const TfLiteTensor* weights = GetInput(context, node, 3);
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);

@ -104,11 +108,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
SizeOfDimension(weights, 0));

const TfLiteTensor* value = GetInput(context, node, 4);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

// Mark the output as a dynamic tensor.
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
output->allocation_type = kTfLiteDynamic;

@ -140,12 +146,18 @@ void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* ids = GetInput(context, node, 0);
const TfLiteTensor* indices = GetInput(context, node, 1);
const TfLiteTensor* dense_shape = GetInput(context, node, 2);
const TfLiteTensor* weights = GetInput(context, node, 3);
const TfLiteTensor* value = GetInput(context, node, 4);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
const TfLiteTensor* dense_shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));

const int lookup_rank = SizeOfDimension(indices, 1);
const int embedding_rank = NumDimensions(value);

@ -73,9 +73,12 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
output->type = input->type;
if (IsConstantTensor(axis)) {
int axis_value;
@ -89,9 +92,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Just copy input to output.
const TfLiteTensor* input = GetInput(context, node, kInput);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
if (IsDynamicTensor(output)) {
int axis_value;
TF_LITE_ENSURE_OK(context,

@ -72,8 +72,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
const TfLiteTensor* value = GetInput(context, node, kValueTensor);
const TfLiteTensor* dims;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));

// Make sure the 1st input tensor is 1-D.
TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);
@ -85,7 +87,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Make sure the 2nd input tensor is a scalar.
TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = value->type;

if (IsConstantTensor(dims)) {
@ -111,12 +115,16 @@ TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* value = GetInput(context, node, kValueTensor);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
const TfLiteTensor* dims;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
}
#define TF_LITE_FILL(data_type) \

@ -35,8 +35,11 @@ enum KernelType {
};

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -47,8 +50,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

template <KernelType type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (type == kGenericOptimized) {
optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),

@ -64,9 +64,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

@ -126,9 +132,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (input1->type) {
case kTfLiteInt32: {

@ -58,9 +58,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

@ -120,9 +126,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (input1->type) {
case kTfLiteInt32: {

@ -155,13 +155,18 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
: 2;
TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* bias =
(node->inputs->size == 3)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Check proper datatype match among all Input Tensors
TF_LITE_ENSURE_STATUS(
@ -214,7 +219,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(5);
node->temporaries->data[0] = data->scratch_tensor_index;

TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
input_quantized->type = filter->type;
input_quantized->allocation_type = kTfLiteArenaRw;

@ -223,7 +230,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));

node->temporaries->data[1] = data->scratch_tensor_index + 1;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;

@ -236,7 +245,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
}

node->temporaries->data[2] = data->scratch_tensor_index + 2;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -250,7 +261,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
}

node->temporaries->data[3] = data->scratch_tensor_index + 3;
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@ -260,7 +273,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
input_offsets_size));
}
node->temporaries->data[4] = data->scratch_tensor_index + 4;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/4, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_units};
@ -300,8 +315,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check for supported activation types.
auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const bool is_quantized =
((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
@ -484,11 +502,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t output_offset = output->params.zero_point;
// Only the Pie path supports quantized models and float inputs/outputs.
if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
TfLiteTensor* input_offsets;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/4, &row_sums));
return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, accum_scratch, row_sums,
input_offsets, output);
@ -693,13 +721,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &filter));
const TfLiteTensor* bias =
(node->inputs->size == 3)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (filter->type) {
case kTfLiteFloat32:
@ -708,8 +741,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
if (params->weights_format ==
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
TfLiteTensor* shuffled_input_workspace =
GetOutput(context, node, kShuffledInputWorkspaceTensor);
TfLiteTensor* shuffled_input_workspace;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
&shuffled_input_workspace));
return EvalShuffledQuantized<kernel_type>(context, node, params, data,
input, filter, bias, output,
shuffled_input_workspace);
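(Aside: optional inputs such as the fully-connected bias above still go through GetOptionalInputTensor and may legitimately be null, so callers keep an explicit guard instead of a *Safe accessor. A minimal sketch of that caller-side pattern, mirroring the code above:

const TfLiteTensor* bias =
    (node->inputs->size == 3)
        ? GetOptionalInputTensor(context, node, kBiasTensor)
        : nullptr;
if (bias != nullptr) {
  // Only touch the bias data when the tensor was actually supplied.
  const float* bias_data = GetTensorData<float>(bias);
  (void)bias_data;
}
)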
@ -38,9 +38,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

const auto* params =
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* positions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPositions, &positions));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (positions->type) {
case kTfLiteInt64:
@ -132,9 +137,14 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* positions;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPositions, &positions));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (positions->type == kTfLiteInt32) {
switch (input->type) {

@ -33,9 +33,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* params = GetInput(context, node, kParams);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* params;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (params->type) {
case kTfLiteFloat32:
@ -140,9 +144,13 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* params = GetInput(context, node, kParams);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* params;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (indices->type) {
case kTfLiteInt32:

@ -37,6 +37,7 @@ limitations under the License.
#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

@ -54,15 +55,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);

const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);

const TfLiteTensor* key = GetInput(context, node, 1);
const TfLiteTensor* key;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);

const TfLiteTensor* value = GetInput(context, node, 2);
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
SizeOfDimension(value, 0));
@ -70,12 +74,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
}

TfLiteTensor* hits = GetOutput(context, node, 1);
TfLiteTensor* hits;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
hitSize->data[0] = SizeOfDimension(lookup, 0);

TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_EQ(context, value->type, output->type);

TfLiteStatus status = kTfLiteOk;
@ -94,11 +100,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* hits = GetOutput(context, node, 1);
const TfLiteTensor* lookup = GetInput(context, node, 0);
const TfLiteTensor* key = GetInput(context, node, 1);
const TfLiteTensor* value = GetInput(context, node, 2);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteTensor* hits;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
const TfLiteTensor* lookup;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
const TfLiteTensor* key;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));

const int num_rows = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / num_rows;

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
@ -52,7 +53,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size > 0);

// The first input is the condition.
const TfLiteTensor* cond = GetInput(context, node, 0);
const TfLiteTensor* cond;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
// Currently only bool is supported.
// TODO(ycling): Support other types since TensorFlow also support
// non-bool types as condition.
@ -83,7 +85,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < num_inputs; ++i) {
// The first input of the node is the condition. The indices of the inputs
// passed to the subgraphs are offset by 1.
const TfLiteTensor* input = GetInput(context, node, i + 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
std::vector<int> dims(input->dims->data,
input->dims->data + input->dims->size);
subgraph->ResizeInputTensor(i, dims);
@ -113,7 +116,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

for (int i = 0; i < num_outputs; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (has_dynamic_output_tensors) {
SetTensorToDynamic(output);
} else {
@ -133,7 +137,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* cond = GetInput(context, node, 0);
const TfLiteTensor* cond;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
bool cond_value = cond->data.b[0];

Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
@ -147,7 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph& active_branch_subgraph =
*(*subgraphs)[active_branch_subgraph_index];
for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
const TfLiteTensor* input = GetInput(context, node, i + 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
TfLiteTensor* subgraph_input =
active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
@ -164,7 +170,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

bool has_dynamic_output_tensors = false;
for (int i = 0; i < node->outputs->size; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (IsDynamicTensor(output)) {
has_dynamic_output_tensors = true;
break;
@ -173,7 +180,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

if (has_dynamic_output_tensors) {
for (int i = 0; i < node->outputs->size; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TfLiteTensor* subgraph_output =
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
@ -185,7 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
const TfLiteTensor* subgraph_output =
active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
}

@ -44,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE(context, NumDimensions(input) <= 4);

@ -74,8 +77,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// TODO(b/143912164): instead of hardcode the epsilon here, we should read it
// from tensorflow, i.e., adding a params.

@ -39,8 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);

@ -61,8 +64,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32) {
#define TF_LITE_LOCAL_RESPONSE_NORM(type) \

@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Reinterprete the opaque data provided by user.
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

@ -84,9 +90,15 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool)) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(

@ -73,22 +73,26 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* hash = GetInput(context, node, 0);
const TfLiteTensor* hash;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
// Support up to 32 bits.
TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);

const TfLiteTensor* input = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);

if (NumInputs(node) == 3) {
const TfLiteTensor* weight = GetInput(context, node, 2);
const TfLiteTensor* weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight));
TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
SizeOfDimension(input, 0));
}

TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
switch (params->type) {
case kTfLiteLshProjectionSparse:
@ -170,9 +174,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);

int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
const TfLiteTensor* hash = GetInput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 1);
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
int32_t* out_buf = out_tensor->data.i32;
const TfLiteTensor* hash;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
const TfLiteTensor* weight =
NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);

@ -149,7 +149,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
const TfLiteTensor* cell_state =
GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));

auto* cell_state_params =
static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
@ -173,25 +175,38 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
OpData* op_data = static_cast<OpData*>(node->user_data);
const bool use_layer_norm = op_data->use_layer_norm;

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -227,7 +242,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
std::vector<int32> intermediate_zp;
for (int i = 0; i < 4; ++i) {
if (use_layer_norm) {
const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
TfLiteTensor* intermediate;
TF_LITE_ENSURE_OK(context,
GetIntermediatesSafe(context, node, i, &intermediate));
auto* params = static_cast<TfLiteAffineQuantization*>(
intermediate->quantization.params);
intermediate_scale.push_back(params->scale->data[0]);
@ -240,7 +257,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
}
// In the absense of projection, hidden becomes otuput and this intermediate
// is ignored.
const TfLiteTensor* hidden = GetIntermediates(context, node, 4);
TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
auto* hidden_params =
static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
intermediate_scale.push_back(hidden_params->scale->data[0]);
@ -446,24 +464,37 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
// Get all tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -483,12 +514,15 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(

const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));

const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -774,7 +808,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;

const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));

auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
cell_state->quantization.params);
@ -825,8 +861,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE(context, params->cell_clip >= 0);
TF_LITE_ENSURE(context, params->proj_clip >= 0);

const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
@ -845,8 +883,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
input_to_forget_weights->type);
}

const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
@ -865,8 +905,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
input_to_forget_weights->type);
}

const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell);
@ -875,8 +917,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type);

const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@ -948,8 +992,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
}

const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -958,8 +1003,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
}

const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -968,8 +1014,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
}

const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
if (is_integer) {
@ -1105,7 +1152,8 @@ TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
OpData* op_data,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
@ -1115,21 +1163,33 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,

const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -1254,20 +1314,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const bool is_integer = input->type == kTfLiteInt8;
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];

const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
n_cell);
@ -1279,7 +1344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_cell, use_layer_norm, is_integer));

// Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
@ -1339,7 +1406,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (!is_integer) {
node->temporaries->data[kScratchBuffer] =
op_data->scratch_tensor_index + kScratchBuffer;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;

@ -1367,8 +1436,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -1378,8 +1448,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
op_data->scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
TfLiteTensor* output_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateQuantized,
&output_state_quantized));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@ -1392,8 +1464,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kCellStateQuantized] =
op_data->scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, kCellStateQuantized);
TfLiteTensor* cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kCellStateQuantized,
&cell_state_quantized));
cell_state_quantized->type = input_to_output_weights->type;
cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@ -1410,7 +1484,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -1422,8 +1499,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -1434,8 +1513,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -1451,8 +1532,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell};
@ -1468,7 +1551,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// multiplication before multiplication by scaling factor
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
@ -1482,7 +1567,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -1493,8 +1580,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -1516,7 +1605,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
}

TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
const int row_sums_dims[2] = {row_sums_rows, n_cell};
@ -1664,8 +1755,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, scratch_index);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8;
@ -1701,8 +1794,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, scratch_index);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
if (scratch_index == 0 || scratch_index == 1) {
scratch_tensor->type = kTfLiteInt8;
} else {
@ -1731,25 +1826,38 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = static_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -1769,12 +1877,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));

const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -1783,16 +1894,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TFLITE_DCHECK(output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TFLITE_DCHECK(cell_state != nullptr);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, 0);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch_buffer));
return lstm_eval::EvalFloat(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -1818,7 +1933,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool is_hybrid = (input->type == kTfLiteFloat32);
const bool is_sparse = input_to_output_weights->sparsity != nullptr;
if (is_hybrid) {
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
const int row_sums_size = row_sums->dims->data[0];
if (is_sparse) {
TfLiteTensor* input_to_input_weights_ledger =
@ -1957,12 +2074,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else {
const int num_intermediate_tensors = node->intermediates->size;
if (num_intermediate_tensors == 5) {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0);
TfLiteTensor* scratch1 = GetTemporary(context, node, 1);
TfLiteTensor* scratch2 = GetTemporary(context, node, 2);
TfLiteTensor* scratch3 = GetTemporary(context, node, 3);
TfLiteTensor* scratch4 = GetTemporary(context, node, 4);
TfLiteTensor* scratch5 = GetTemporary(context, node, 5);
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch1;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -1978,14 +2107,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0);
TfLiteTensor* scratch1 = GetTemporary(context, node, 1);
TfLiteTensor* scratch2 = GetTemporary(context, node, 2);
TfLiteTensor* scratch3 = GetTemporary(context, node, 3);
TfLiteTensor* scratch4 = GetTemporary(context, node, 4);
TfLiteTensor* scratch5 = GetTemporary(context, node, 5);
TfLiteTensor* scratch6 = GetTemporary(context, node, 6);
TfLiteTensor* scratch7 = GetTemporary(context, node, 7);
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch1;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
TfLiteTensor* scratch6;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 6, &scratch6));
TfLiteTensor* scratch7;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 7, &scratch7));
return lstm_eval::EvalInteger8x8_8(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -2046,12 +2191,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);

const TfLiteTensor* input = GetInput(context, node, kInputData);
const TfLiteTensor* prev_activation =
GetInput(context, node, kInputPrevActivation);
const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
const TfLiteTensor* prev_activation;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
&prev_activation));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));

TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
const int num_batches = input->dims->data[0];
@ -2073,11 +2225,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);

TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
TfLiteTensor* activation_temp =
GetOutput(context, node, kOutputActivationTemp);
TfLiteTensor* activation_out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
&activation_out));
TfLiteTensor* state_out;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));

TF_LITE_ENSURE_OK(context, context->ResizeTensor(
context, activation_out,
@ -2106,18 +2265,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputData);
const TfLiteTensor* prev_activation =
GetInput(context, node, kInputPrevActivation);
const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
const TfLiteTensor* prev_activation;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
&prev_activation));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));

TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
TfLiteTensor* activation_temp =
GetOutput(context, node, kOutputActivationTemp);
TfLiteTensor* activation_out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
&activation_out));
TfLiteTensor* state_out;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));

if (input->type == kTfLiteFloat32 &&
prev_activation->type == kTfLiteFloat32 &&

@ -32,12 +32,15 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 1);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
// Resize the output tensor.
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
for (int i = 0; i < input_dims_size; i++) {
@ -116,8 +119,11 @@ void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
FillDiagHelper(input, output);
return kTfLiteOk;
}

@ -33,12 +33,15 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteIntArray* input_dims = input->dims;
int input_dims_size = input_dims->size;
TF_LITE_ENSURE(context, input_dims_size >= 2);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
for (int i = 0; i < input_dims_size; i++) {
@ -126,9 +129,14 @@ void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* diag;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDiagonalTensor, &diag));
FillDiagHelper(input, diag, output);
return kTfLiteOk;
}

@ -73,9 +73,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_wav;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorWav, &input_wav));
const TfLiteTensor* input_rate;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorRate, &input_rate));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
@ -101,9 +107,15 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);

const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_wav;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorWav, &input_wav));
const TfLiteTensor* input_rate;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorRate, &input_rate));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

const int32 sample_rate = *GetTensorData<int>(input_rate);

@ -162,8 +162,10 @@ struct MirrorPadWorkerTask : cpu_backend_threadpool::Task {

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ruy::profiler::ScopeLabel label("MirrorPad");
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
const TfLiteTensor* padding_matrix;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
auto* params =
reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);

@ -172,7 +174,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
const int input_dims = NumDimensions(input_tensor);

TfLiteTensor* output_tensor = GetOutput(context, node, 0);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
if (IsDynamicTensor(output_tensor)) {
auto output_size = GetPaddedOutputShape(input_tensor, padding_matrix);
if (output_size == nullptr) {
@ -258,9 +261,12 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
TfLiteTensor* output_tensor = GetOutput(context, node, 0);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
const TfLiteTensor* padding_matrix;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));

TF_LITE_ENSURE_EQ(context, NumDimensions(padding_matrix), 2);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(padding_matrix, 0),

@ -75,9 +75,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

@ -259,9 +265,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
@ -34,8 +34,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

output->type = input->type;
return context->ResizeTensor(context, output,
@ -43,8 +46,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input->type) {
case kTfLiteInt64:
reference_ops::Negate(
@ -79,20 +79,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

// Boxes & Scores.
const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
const TfLiteTensor* input_boxes;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4);
const int num_boxes = SizeOfDimension(input_boxes, 0);
const TfLiteTensor* input_scores =
GetInput(context, node, kInputTensorScores);
const TfLiteTensor* input_scores;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1);
TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0));

// Max output size.
const TfLiteTensor* input_max_output_size =
GetInput(context, node, kInputTensorMaxOutputSize);
const TfLiteTensor* input_max_output_size;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorMaxOutputSize,
&input_max_output_size));
TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
@ -103,30 +108,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

// IoU & Score thresholds.
const TfLiteTensor* input_iou_threshold =
GetInput(context, node, kInputTensorIouThreshold);
const TfLiteTensor* input_iou_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorIouThreshold,
&input_iou_threshold));
TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0);
const TfLiteTensor* input_score_threshold =
GetInput(context, node, kInputTensorScoreThreshold);
const TfLiteTensor* input_score_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorScoreThreshold,
&input_score_threshold));
TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0);

if (is_soft_nms) {
const TfLiteTensor* input_sigma =
GetInput(context, node, kInputTensorSigma);
const TfLiteTensor* input_sigma;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0);

TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3);
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
TfLiteTensor* output_selected_indices;
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
&output_selected_indices));
output_selected_indices->type = kTfLiteInt32;
TfLiteTensor* output_selected_scores =
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
TfLiteTensor* output_selected_scores;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kSoftNMSOutputTensorSelectedScores,
&output_selected_scores));
output_selected_scores->type = kTfLiteFloat32;
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
TfLiteTensor* output_num_selected_indices;
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});

@ -139,11 +157,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
} else {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
TfLiteTensor* output_selected_indices;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
&output_selected_indices));
output_selected_indices->type = kTfLiteInt32;
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
TfLiteTensor* output_num_selected_indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});
@ -179,20 +201,29 @@ void ResetUnusedElementsToZeroes(const int max_output_size,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool is_soft_nms = NumInputs(node) == 6;

const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
const TfLiteTensor* input_boxes;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
const int num_boxes = SizeOfDimension(input_boxes, 0);
const TfLiteTensor* input_scores =
GetInput(context, node, kInputTensorScores);
const TfLiteTensor* input_max_output_size =
GetInput(context, node, kInputTensorMaxOutputSize);
const TfLiteTensor* input_scores;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
const TfLiteTensor* input_max_output_size;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorMaxOutputSize,
&input_max_output_size));
const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
const TfLiteTensor* input_iou_threshold =
GetInput(context, node, kInputTensorIouThreshold);
const TfLiteTensor* input_iou_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorIouThreshold,
&input_iou_threshold));
const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
const TfLiteTensor* input_score_threshold =
GetInput(context, node, kInputTensorScoreThreshold);
const TfLiteTensor* input_score_threshold;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorScoreThreshold,
&input_score_threshold));
const float score_threshold = *GetTensorData<float>(input_score_threshold);

TfLiteTensor* output_selected_indices = nullptr;
@ -200,8 +231,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_num_selected_indices = nullptr;

if (is_soft_nms) {
const TfLiteTensor* input_sigma =
GetInput(context, node, kInputTensorSigma);
const TfLiteTensor* input_sigma;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
const float soft_nms_sigma = *GetTensorData<float>(input_sigma);
if (soft_nms_sigma < 0) {
context->ReportError(context, "Invalid sigma value for soft NMS: %f",
@ -209,12 +241,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}

output_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
output_selected_scores =
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
&output_selected_indices));
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kSoftNMSOutputTensorSelectedScores,
&output_selected_scores));
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
@ -228,10 +265,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
max_output_size_value, *output_num_selected_indices->data.i32,
output_selected_indices->data.i32, output_selected_scores->data.f);
} else {
output_selected_indices =
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
&output_selected_indices));
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
kNMSOutputTensorNumSelectedIndices,
&output_num_selected_indices));
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
}
@ -109,7 +109,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(1);
node->temporaries->data[0] = op_data->cache_tensor_id;

TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* dequantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &dequantized));
dequantized->type = op_context.ref->type;
dequantized->allocation_type = kTfLiteDynamic;

@ -142,7 +144,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}

// Dequantize the input
TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* dequantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &dequantized));
auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
context, node, op_context.input, dequantized);
if (status != kTfLiteOk) {
@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input0 = GetInput(context, node, 0);
const TfLiteTensor* input0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
const int dimension_size = NumDimensions(input0) + 1;
if (data->axis < 0) {
data->axis += dimension_size;
@ -55,7 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Make sure all inputs have the same shape and type.
for (int i = 1; i < data->values_count; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
}
@ -72,13 +74,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);

// Guarantee input/output quantization params match as we do not support
// packing quantized tensors.
for (int i = 0; i < data->values_count; i++) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output->params.zero_point);
TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
@ -106,7 +111,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLitePackParams* data =
reinterpret_cast<TfLitePackParams*>(node->builtin_data);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (output->type) {
case kTfLiteFloat32: {
return PackImpl<float>(context, node, output, data->values_count,
@ -71,8 +71,10 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

@ -368,8 +370,10 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
@ -399,8 +403,10 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
@ -430,8 +436,10 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
L2EvalFloat<kernel_type>(context, node, params, data, input, output);
@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

@ -112,9 +118,15 @@ TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (output->type) {
case kTfLiteInt32: {
@ -97,8 +97,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only support affine per-layer quantization.
@ -141,8 +143,10 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

const RuntimeShape input_shape = GetTensorShape(input);
const RuntimeShape output_shape = GetTensorShape(output);
@ -83,9 +83,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* start = GetInput(context, node, kStartTensor);
const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
const TfLiteTensor* start;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
const TfLiteTensor* limit;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
const TfLiteTensor* delta;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
// Make sure all the inputs are scalars.
TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
@ -103,7 +106,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = dtype;

if (IsConstantTensor(start) && IsConstantTensor(limit) &&
@ -130,11 +135,16 @@ void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* start = GetInput(context, node, kStartTensor);
const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
const TfLiteTensor* start;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
const TfLiteTensor* limit;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
const TfLiteTensor* delta;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
@ -31,8 +31,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = kTfLiteInt32;

// By design, the input shape is always known at the time of Prepare, even
@ -34,12 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);

TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputValue, &output));
SetTensorToDynamic(output);

return kTfLiteOk;
@ -48,15 +51,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputVariableId);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];
auto& resources = subgraph->resources();
auto* variable = resource::GetResourceVariable(&resources, resource_id);
TF_LITE_ENSURE(context, variable != nullptr);

TfLiteTensor* variable_tensor = variable->GetTensor();
TfLiteTensor* output = GetOutput(context, node, kOutputValue);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputValue, &output));

TF_LITE_ENSURE_TYPES_EQ(context, variable_tensor->type, output->type);
TF_LITE_ENSURE_OK(
@ -170,7 +170,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
scratch_tensor->type = kTfLiteInt32;
scratch_tensor->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
@ -180,11 +182,15 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,

// Creates a temp tensor to store resolved axis given input data.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
resolved_axis->type = kTfLiteInt32;
// Creates a temp tensor to store temp sums when calculating mean.
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
switch (op_context->input->type) {
case kTfLiteFloat32:
temp_sum->type = kTfLiteFloat32;
@ -217,7 +223,9 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));

TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
// Leaves work to Eval if axis is not constant; else resizes output.
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(op_context.output);
@ -233,7 +241,8 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
return PrepareSimple(context, node);
}
@ -254,7 +263,9 @@ TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
data->shift = exponent;
}
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(temp_sum);
return kTfLiteOk;
@ -343,9 +354,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context,
@ -490,8 +507,12 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
OpContext* op_context, T init_value,
T reducer(const T current, const T in)) {
int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
@ -621,9 +642,15 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
if (need_rescale) {
// Rescaling 8bit reduce sum.
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* temp_index;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
TfLiteTensor* resolved_axis;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
TfLiteTensor* temp_sum;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context,
@ -38,8 +38,11 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
scoped_output_shape(output_shape, TfLiteIntArrayFree);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Tensorflow's Reshape allows one of the shape components to have the
// special -1 value, meaning it will be calculated automatically based on the
@ -70,6 +73,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
TfLiteNode* node) {
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
if (shape == nullptr) return nullptr;

TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
for (int i = 0; i < output_shape->size; ++i) {
@ -103,7 +107,8 @@ inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
// Check if the shape tensor is valid. Shapes should be int32 vectors.
inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
return (shape != nullptr && shape->dims->size == 1 &&
shape->type == kTfLiteInt32);
}

TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) {
@ -122,7 +127,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// calculate their shapes now. String tensors don't benefit from having their
// shapes precalculated because the actual memory can only be allocated after
// we know all the content.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type != kTfLiteString) {
if (NumInputs(node) == 1 ||
IsConstantTensor(GetInput(context, node, kShapeTensor))) {
@ -135,8 +142,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// There are two ways in which the 'output' can be made dynamic: it could be
// a string tensor, or its shape cannot be calculated during Prepare(). In
@ -61,9 +61,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// TODO(ahentz): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@ -96,9 +100,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
@ -60,9 +60,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// TODO(ahentz): Our current implementations rely on the input being 4D,
// and the size being 1D tensor with exactly 2 elements.
@ -85,9 +89,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeNearestNeighborParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
@ -35,8 +35,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));

@ -59,7 +61,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(context, "Current does not support more than 1 axis.");
}

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);

@ -67,8 +71,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* axis_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kAxisTensor, &axis_tensor));
int axis = GetTensorData<int32_t>(axis_tensor)[0];
const int rank = NumDimensions(input);
if (axis < 0) {
@ -76,7 +83,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}

TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (output->type) {
case kTfLiteFloat32: {
@ -36,8 +36,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* seq_lengths = GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* seq_lengths;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);

if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
@ -56,7 +59,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);

@ -65,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

template <typename T, typename TS>
TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* seq_lengths_tensor =
GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* seq_lengths_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
&seq_lengths_tensor));
const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);

auto* params =
@ -86,7 +93,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
}

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

reference_ops::ReverseSequence<T, TS>(
seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
@ -98,8 +107,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {

template <typename T>
TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* seq_lengths_tensor =
GetInput(context, node, kSeqLengthsTensor);
const TfLiteTensor* seq_lengths_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
&seq_lengths_tensor));
switch (seq_lengths_tensor->type) {
case kTfLiteInt32: {
return ReverseSequenceImpl<T, int32_t>(context, node);
@ -119,7 +129,9 @@ TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (output->type) {
case kTfLiteFloat32: {
@ -73,16 +73,20 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
data->fft_double_working_area_id = first_new_index + 1;

// Set up FFT integer working area buffer.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
fft_integer_working_area->type = kTfLiteInt32;
// If fft_length is not a constant tensor, fft_integer_working_area will be
// set to dynamic later in Prepare.
fft_integer_working_area->allocation_type = kTfLiteArenaRw;

// Set up FFT double working area buffer.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
// fft_double_working_area is a double tensor. Ideally, double should be
// added into tflite data types. However, since fft_double_working_area is a
// temporary tensor, and there are no ops having double input/output tensors
@ -100,10 +104,13 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,

TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int num_dims = NumDimensions(input);
TF_LITE_ENSURE(context, num_dims >= 2);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
// The lib, fft2d, can only handle fft_lengths of power of 2.
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
@ -116,15 +123,19 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
int half_fft_working_length = fft_working_length / 2;

// Resize output tensor.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
output_shape->data[num_dims - 2] = fft_length_data[0];
output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));

// Resize temporary tensors, fft_integer_working_area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
fft_integer_working_area_shape->data[0] =
2 + static_cast<int>(sqrt(fft_working_length));
@ -132,8 +143,10 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
fft_integer_working_area_shape));

// Resize temporary tensors, fft_double_working_area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
fft_double_working_area_shape->data[0] =
half_fft_working_length + fft_width / 4;
@ -157,7 +170,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

// Check type and shape of the input tensor
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
if (input->type != kTfLiteFloat32) {
context->ReportError(context,
@ -167,7 +181,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

// Check type and shape of the fft_length tensor
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const RuntimeShape fft_length_shape = GetTensorShape(fft_length);

TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
@ -183,17 +199,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));

// Set output type
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = kTfLiteComplex64;

// Exit early if fft_length is a non-const tensor. Set output tensor and
// temporary tensors to dynamic, so that their tensor sizes can be determined
// in Eval.
if (!IsConstantTensor(fft_length)) {
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
SetTensorToDynamic(fft_integer_working_area);
SetTensorToDynamic(fft_double_working_area);
SetTensorToDynamic(output);
@ -325,11 +347,16 @@ void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
}

TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const float* input_data = GetTensorData<float>(input);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
complex<float>* output_data = GetTensorData<complex<float>>(output);

int fft_height, fft_width;
@ -358,14 +385,18 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
}

// Get buffer for integer working area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_integer_working_area;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
&fft_integer_working_area));
int* fft_integer_working_area_data =
GetTensorData<int>(fft_integer_working_area);

// Get buffer for double working area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteTensor* fft_double_working_area;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
&fft_double_working_area));
// Get double value out of the memory of fft_double_working_area_data.
double* fft_double_working_area_data = reinterpret_cast<double*>(
GetTensorData<int64_t>(fft_double_working_area));
@ -393,10 +424,15 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* fft_length;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kFftLengthTensor, &fft_length));
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type != kTfLiteComplex64) {
context->ReportError(context,
@ -30,8 +30,11 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));

@ -74,9 +74,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
const TfLiteTensor* updates;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));

switch (updates->type) {
case kTfLiteFloat32:
@ -96,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = updates->type;

if (IsConstantTensor(shape)) {
@ -163,10 +168,15 @@ TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
const TfLiteTensor* updates;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (indices->type) {
case kTfLiteInt32:

@ -64,11 +64,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

const TfLiteTensor* data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputDataTensor, &data));
const TfLiteTensor* segment_ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
&segment_ids));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE(context,
data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
@ -82,10 +86,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* data;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputDataTensor, &data));
const TfLiteTensor* segment_ids;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
&segment_ids));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

@ -61,11 +61,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_condition;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
&input_condition));
const TfLiteTensor* input_x;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorX, &input_x));
const TfLiteTensor* input_y;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorY, &input_y));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Input must be bool.
TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
@ -111,11 +118,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input_condition;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
&input_condition));
const TfLiteTensor* input_x;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorX, &input_x));
const TfLiteTensor* input_y;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorY, &input_y));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

#define TF_LITE_SELECT(type, op) \
reference_ops::op(GetTensorShape(input_condition), \

@ -40,8 +40,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
switch (params->out_type) {

@ -35,6 +35,7 @@ limitations under the License.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

@ -48,10 +49,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

TF_LITE_ENSURE_TYPES_EQ(context, GetInput(context, node, 0)->type,
kTfLiteString);
TF_LITE_ENSURE_TYPES_EQ(context, GetOutput(context, node, 0)->type,
kTfLiteString);
const TfLiteTensor* input_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
TF_LITE_ENSURE_TYPES_EQ(context, input_tensor->type, kTfLiteString);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteString);
return kTfLiteOk;
}

@ -91,7 +94,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

// Split sentence to words.
std::vector<StringRef> words;
tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
tflite::StringRef strref = tflite::GetString(input, 0);
int prev_idx = 0;
for (int i = 1; i < strref.len; i++) {
if (isspace(*(strref.str + i))) {

@ -113,10 +113,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* begin;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Ensure validity of input tensor and its dimension.
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@ -142,10 +147,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* begin;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
const TfLiteTensor* size;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);

@ -80,8 +83,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
tflite::SpaceToDepthParams op_params; \

@ -143,12 +143,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* default_value =
GetInput(context, node, kDefaultValueTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));
const TfLiteTensor* default_value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value));

// TODO(renjieliu): Handle validate_indices.

@ -178,7 +184,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, CheckDimensionsMatch(context, indices, output_shape, values));

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = values->type;
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);

@ -191,13 +199,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

template <typename T, typename TI>
TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* default_value =
GetInput(context, node, kDefaultValueTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));
const TfLiteTensor* default_value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
@ -238,8 +254,12 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
const TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kIndicesTensor, &indices));
const TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueInputTensor, &values));

switch (values->type) {
case kTfLiteFloat32:

@ -41,7 +41,9 @@ struct OpContext {

TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < NumOutputs(node); ++i) {
SetTensorToDynamic(GetOutput(context, node, i));
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
SetTensorToDynamic(tensor);
}
return kTfLiteOk;
}
@ -65,7 +67,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
for (int i = 0; i < NumOutputs(node); ++i) {
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output_dims->data[axis_value] = slice_size;
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
}

@ -85,7 +88,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
input_type == kTfLiteInt32);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
tensor->type = input_type;
}

// If we know the contents of the 'axis' tensor, resize all outputs.

@ -45,7 +45,9 @@ struct OpContext {

TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < NumOutputs(node); ++i) {
SetTensorToDynamic(GetOutput(context, node, i));
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
SetTensorToDynamic(tensor);
}
return kTfLiteOk;
}
@ -113,7 +115,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
for (int i = 0; i < NumOutputs(node); ++i) {
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output_dims->data[axis_value] = size_splits_vector.at(i);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
}

@ -133,7 +136,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64 || input_type == kTfLiteInt8);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
TfLiteTensor* tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
tensor->type = input_type;
}

auto size_splits = op_context.size_splits;

@ -60,9 +60,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -101,9 +107,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
ruy::profiler::ScopeLabel label("SquaredDifference");

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32) {
EvalSquaredDifference<float>(context, node, data, input1, input2, output);

@ -217,9 +217,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
@ -435,9 +441,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
output->type == kTfLiteInt64) {

@ -82,11 +82,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* weights_feature;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
&weights_feature));
const TfLiteTensor* weights_time;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));

TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
@ -108,8 +111,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}

const TfLiteTensor* state = GetInput(context, node, kStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* state;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Check the shape of input state tensors.
TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
@ -143,7 +149,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_size_array->data[0] = batch_size;
scratch_size_array->data[1] = num_filters;

TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));

// The scratch buffer is of type int32 for full integer svdf and it's of type
// float32 for hybrid and float case.
@ -161,7 +169,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Tell interpreter to allocate temporary tensors to store quantized values
// of input tensors.
node->temporaries->data[1] = scratch_tensor_index + 1;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&input_quantized));
input_quantized->type = weights_feature->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -172,7 +182,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// Tell interpreter to allocate temporary tensors to store scaling factors.
node->temporaries->data[2] = scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -186,7 +198,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Used to store dequantized weights_time matrix for hybrid computation of
// matmul(state, weights_time), which occurs in floating point.
node->temporaries->data[3] = scratch_tensor_index + 3;
TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* float_weights_time;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&float_weights_time));
float_weights_time->type = kTfLiteFloat32;
// Persistent so that we can compute the dequantized weights only once.
float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
@ -199,7 +213,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

node->temporaries->data[4] = scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -211,7 +227,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

node->temporaries->data[5] = scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
row_sums->type = kTfLiteFloat32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_filters};
@ -228,7 +246,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_temp_size_array->data[0] = num_units;
output_temp_size_array->data[1] = batch_size;
node->temporaries->data[1] = scratch_tensor_index + 1;
TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* output_temp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
output_temp->type = kTfLiteInt32;
output_temp->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
@ -263,17 +283,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* weights_feature;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
&weights_feature));
const TfLiteTensor* weights_time;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/0, &scratch));

TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (weights_feature->type) {
case kTfLiteFloat32: {
@ -286,14 +313,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scaling_factors =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* float_weights_time =
GetTemporary(context, node, /*index=*/3);
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&input_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
TfLiteTensor* float_weights_time;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
&float_weights_time));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
&zero_points));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
// Dequantize weights time.
// TODO(alanchiao): this dequantization initialization only needs to
// happen once per model and should theoretically be placed in either
@ -322,7 +356,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input->quantization.params);
auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* output_temp;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&output_temp));

// Currently supports only ReLU.
// TODO(jianlijianli): support other activations.

@ -49,9 +49,14 @@ TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
}

TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));

const int num_dimensions = NumDimensions(input);
const int num_multipliers = NumElements(multipliers);
@ -208,12 +213,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
// Only int32 and int64 multipliers type is supported.
if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
context->ReportError(context,
@ -231,9 +241,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteTensor* multipliers;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kInputMultipliers, &multipliers));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));

@ -35,14 +35,16 @@ constexpr int kOutputIndexes = 1;

namespace {
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
// INT32 number of top results is supported.
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
// Check that the tensor contains only one value.
TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
const int32 k = *GetTensorData<int32_t>(top_k);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int num_dimensions = NumDimensions(input);
// Check that input has one or more dimensions.
TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
@ -59,8 +61,12 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
}
output_indexes_shape->data[num_dimensions - 1] = k;
output_values_shape->data[num_dimensions - 1] = k;
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
// Force output types.
output_indexes->type = kTfLiteInt32;
output_values->type = input->type;
@ -195,19 +201,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);

const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);

// Set output dynamic if the input is not const.
if (IsConstantTensor(top_k)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
} else {
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
SetTensorToDynamic(output_indexes);
SetTensorToDynamic(output_values);
}
@ -215,16 +229,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputValues, &output_values));
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
if (IsDynamicTensor(output_values)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
}
const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const TfLiteTensor* top_k;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
const int32 k = top_k->data.i32[0];
// The tensor can have more than 2 dimensions or even be a vector, the code
// anyway calls the internal dimension as row;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const int32 row_size = input->dims->data[input->dims->size - 1];
int32 num_rows = 1;
for (int i = 0; i < input->dims->size - 1; ++i) {

@ -250,13 +250,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

// Retrieve tensors
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &weights));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
const TfLiteTensor* bias = nullptr;

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Tensor sanity checks
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@ -306,7 +313,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* col2im = nullptr;
if (data->has_col2im) {
node->temporaries->data[data->col2im_index] = data->col2im_id;
col2im = GetTemporary(context, node, user_data->col2im_index);
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, user_data->col2im_index, &col2im));
}

if (!IsConstantTensor(output_shape)) {
@ -326,8 +335,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (data->weights_are_transposed) {
node->temporaries->data[data->transposed_weights_index] =
data->transposed_weights_id;
TfLiteTensor* transposed_weights =
GetTemporary(context, node, user_data->transposed_weights_index);
TfLiteTensor* transposed_weights;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, user_data->transposed_weights_index,
&transposed_weights));
if (!IsConstantTensor(weights)) {
SetTensorToDynamic(transposed_weights);
} else {
@ -339,8 +351,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input->type == kTfLiteInt16) {
node->temporaries->data[data->scratch_tensor_index] =
data->scratch_tensor_id;
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (input->type == kTfLiteInt16) {
scratch_buffer->type = kTfLiteInt64;
} else {
@ -549,15 +563,22 @@ void EvalQuantizedPerChannel16x8(
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Retrieve tensors (All should be allocated by now)
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
const TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightsTensor, &weights));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
const TfLiteTensor* bias =
(NumInputs(node) == 4)
? GetOptionalInputTensor(context, node, kBiasTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* col2im = data->has_col2im
? GetTemporary(context, node, data->col2im_index)
@ -604,8 +625,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteUInt8: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));
@ -621,8 +644,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt8: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));
@ -636,8 +661,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
TfLiteTensor* scratch_buffer =
GetTemporary(context, node, data->scratch_tensor_index);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (IsDynamicTensor(scratch_buffer)) {
TF_LITE_ENSURE_OK(context,
ResizeTensor(context, output_shape, scratch_buffer));

@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
@ -88,14 +89,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
|
||||
}
|
||||
|
||||
const TfLiteTensor* input_to_forget_weights =
|
||||
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
|
||||
const TfLiteTensor* input_to_forget_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
|
||||
&input_to_forget_weights));
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
|
||||
|
||||
const TfLiteTensor* input_to_cell_weights =
|
||||
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
|
||||
const TfLiteTensor* input_to_cell_weights;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
|
||||
lstm::full::kInputToCellWeightsTensor,
|
||||
&input_to_cell_weights));
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
|
||||
@ -110,16 +116,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
n_output);
|
||||
}
|
||||
|
||||
const TfLiteTensor* recurrent_to_forget_weights =
|
||||
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
|
||||
const TfLiteTensor* recurrent_to_forget_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
|
||||
&recurrent_to_forget_weights));
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
|
||||
n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
|
||||
n_output);
|
||||
|
||||
const TfLiteTensor* recurrent_to_cell_weights =
|
||||
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
|
||||
const TfLiteTensor* recurrent_to_cell_weights;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
|
||||
&recurrent_to_cell_weights));
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
|
||||
@ -176,18 +188,24 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
|
||||
}
|
||||
|
||||
const TfLiteTensor* forget_gate_bias =
|
||||
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
|
||||
const TfLiteTensor* forget_gate_bias;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
|
||||
&forget_gate_bias));
|
||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);

const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
&output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

@ -229,27 +247,33 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
kTfLiteFloat32);
}

const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
const TfLiteTensor* forget_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
lstm::full::kForgetLayerNormCoefficientsTensor,
&forget_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);

const TfLiteTensor* cell_layer_norm_coefficients =
GetInput(context, node, lstm::full::kCellLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
const TfLiteTensor* cell_layer_norm_coefficients;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node,
lstm::full::kCellLayerNormCoefficientsTensor,
&cell_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);

const TfLiteTensor* output_layer_norm_coefficients =
GetInput(context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
const TfLiteTensor* output_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
lstm::full::kOutputLayerNormCoefficientsTensor,
&output_layer_norm_coefficients));
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
@ -291,7 +315,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const auto* params =
@ -301,14 +327,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
const int n_input = input->dims->data[2];

const TfLiteTensor* input_to_output_weights =
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
n_cell);
@ -320,7 +352,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_cell, is_layer_norm_lstm));

// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
lstm::full::kOutputTensor, &output));

TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
@ -351,7 +385,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_tensor_index + kScratchBuffer;

// Create a scratch buffer tensor.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;

@ -376,8 +412,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -387,8 +424,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
TfLiteTensor* output_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateQuantized,
&output_state_quantized));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@ -401,8 +440,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kCellStateQuantized] =
scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, kCellStateQuantized);
TfLiteTensor* cell_state_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kCellStateQuantized,
&cell_state_quantized));
cell_state_quantized->type = input_to_output_weights->type;
cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@ -420,7 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
@ -432,8 +476,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
TfLiteTensor* output_state_sf;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -444,8 +490,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kProductScalingFactors] =
scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* prod_scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -461,8 +509,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
TfLiteTensor* recovered_cell_weights;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell};
@ -478,7 +528,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratch] =
scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
@ -492,7 +544,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -503,8 +557,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
TfLiteTensor* output_state_zp;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -514,7 +570,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_state_zp_size));
}
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8;
@ -542,25 +600,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const bool time_major = params->time_major;
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));

const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));

const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToInputWeightsTensor);
@ -571,12 +648,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
&output_gate_bias));

const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
@ -584,14 +667,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));

TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TFLITE_DCHECK(output_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TFLITE_DCHECK(cell_state != nullptr);

const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm
@ -614,7 +699,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
lstm::full::kOutputLayerNormCoefficientsTensor)
: nullptr;

TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
lstm::full::kOutputTensor, &output));

// Copy out the LSTM specific params so they can be passed in the function.
TfLiteLSTMParams lstm_params;
@ -647,7 +734,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid(
input, input_to_input_weights,

@ -61,13 +61,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* hidden_state =
GetInput(context, node, kHiddenStateTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
const TfLiteTensor* hidden_state;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

// Check all the parameters of tensor match within themselves and match the
// input configuration.
@ -92,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
@ -112,7 +121,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
&input_quantized));
input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -121,8 +132,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
&hidden_state_quantized));
hidden_state_quantized->type = input_weights->type;
hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
@ -134,7 +146,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
hidden_state_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
&scaling_factors));
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
@ -145,7 +159,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
@ -158,7 +174,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
@ -169,7 +187,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, /*index=*/5, &row_sums));
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
@ -335,15 +355,24 @@ TfLiteStatus EvalHybrid(

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
const TfLiteTensor* recurrent_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
// The hidden_state is a variable input tensor that can be modified.
TfLiteTensor* hidden_state =
const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
GetVariableInput(context, node, kHiddenStateTensor);
TF_LITE_ENSURE(context, hidden_state != nullptr);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

switch (input_weights->type) {
case kTfLiteFloat32:
@ -353,12 +382,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8: {
// TODO(mirkov): implement eval with quantized inputs as well.
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
TfLiteTensor* input_quantized;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &input_quantized));
TfLiteTensor* hidden_state_quantized;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
TfLiteTensor* scaling_factors;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scaling_factors));
TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &accum_scratch));
TfLiteTensor* zero_points;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &zero_points));
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points,

@ -44,11 +44,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output_unique_tensor =
GetOutput(context, node, kOutputUniqueTensor);
TfLiteTensor* output_index_tensor =
GetOutput(context, node, kOutputIndexTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output_unique_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
&output_unique_tensor));
TfLiteTensor* output_index_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
&output_index_tensor));

// The op only supports 1D input.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
@ -70,7 +73,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
// Note that we prefer to use map than unordered_map as it showed less
// increase in the binary size.
std::map<T, int> unique_values;
TfLiteTensor* output_indexes = GetOutput(context, node, 1);
TfLiteTensor* output_indexes;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
std::vector<T> output_values;
I* indexes = GetTensorData<I>(output_indexes);
const T* data = GetTensorData<T>(input);
@ -88,7 +92,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
}
}
// Allocate output tensor.
TfLiteTensor* unique_output = GetOutput(context, node, 0);
TfLiteTensor* unique_output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
shape->data[0] = unique_values.size();
@ -127,8 +132,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
} // namespace

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output_index_tensor;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, 1, &output_index_tensor));
TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
NumElements(input));

@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TF_LITE_ENSURE(context, NumElements(input) > 0);
int axis = data->axis;
if (axis < 0) {
@ -67,7 +68,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]);
for (int i = 0; i < data->num; ++i) {
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
// Guarantee input/output quantization params match as we do not support
// rescaling of unpacked quantized tensors.
@ -98,7 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteUnpackParams* data =
reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
switch (input->type) {
case kTfLiteFloat32: {
UnpackImpl<float>(context, node, input, data->num, data->axis);

@ -56,9 +56,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* cond_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
&cond_tensor));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (cond_tensor->type != kTfLiteBool) {
context->ReportError(context,
@ -81,9 +84,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* cond_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
&cond_tensor));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,

@ -195,7 +195,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
}
for (int i = 0; i < num_inputs; ++i) {
TfLiteTensor* output = GetOutput(context, node, i);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
if (op_data->body_has_dynamic_output_tensors) {
SetTensorToDynamic(output);
} else {

@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input->type;

return context->ResizeTensor(context, output,
@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const int num_elements = NumElements(input);
switch (input->type) {
case kTfLiteInt64:
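For reference, every hunk above applies the same rewrite: the raw GetInput/GetOutput/GetTemporary call is replaced by its *Safe counterpart and wrapped in TF_LITE_ENSURE_OK, so a missing tensor fails Prepare/Eval with kTfLiteError instead of being dereferenced as nullptr. Below is a minimal illustrative sketch of that pattern, not part of the commit; the op name and tensor indices are placeholders, while GetInputSafe, GetOutputSafe, and TF_LITE_ENSURE_OK are the kernel_util.h / common.h helpers used throughout the diff.

// Illustrative sketch only: a hypothetical kernel's Prepare using the
// checked-getter pattern this commit applies to existing kernels.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace example_op {  // hypothetical namespace, for illustration

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Before the change, kernels wrote:
  //   const TfLiteTensor* input = GetInput(context, node, 0);
  // and dereferenced the result even though it may now be nullptr.
  // The *Safe variants report the failure through the context, and
  // TF_LITE_ENSURE_OK propagates kTfLiteError out of Prepare.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  // From here on, input and output are known to be non-null.
  output->type = input->type;
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

}  // namespace example_op
}  // namespace tflite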