Initial port of Transpose from lite to micro
This commit is contained in:
parent
b5c8d770f5
commit
988a602ddd
181
tensorflow/lite/micro/kernels/transpose.cc
Normal file
181
tensorflow/lite/micro/kernels/transpose.cc
Normal file
@ -0,0 +1,181 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace transpose {
|
||||
|
||||
// This file has two implementations of Transpose.
|
||||
enum KernelType {
|
||||
kReference,
|
||||
kGenericOptimized,
|
||||
};
|
||||
|
||||
struct TransposeContext {
|
||||
TransposeContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
input = GetInput(context, node, 0);
|
||||
perm = GetInput(context, node, 1);
|
||||
output = GetOutput(context, node, 0);
|
||||
}
|
||||
const TfLiteTensor* input;
|
||||
const TfLiteTensor* perm;
|
||||
TfLiteTensor* output;
|
||||
};
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
TransposeContext* op_context) {
|
||||
int dims = NumDimensions(op_context->input);
|
||||
const int* perm_data = GetTensorData<int32_t>(op_context->perm);
|
||||
|
||||
// Ensure validity of the permutations tensor as a 1D tensor.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
|
||||
TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims);
|
||||
for (int idx = 0; idx < dims; ++idx) {
|
||||
TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims),
|
||||
"Transpose op permutations array is out of bounds.");
|
||||
}
|
||||
|
||||
// Determine size of output tensor.
|
||||
TfLiteIntArray* input_size = op_context->input->dims;
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
|
||||
for (int idx = 0; idx < dims; ++idx) {
|
||||
output_size->data[idx] = input_size->data[perm_data[idx]];
|
||||
}
|
||||
|
||||
return context->ResizeTensor(context, op_context->output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
TransposeContext op_context(context, node);
|
||||
|
||||
// Ensure validity of input tensor.
|
||||
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
|
||||
"Transpose op only supports 1D-5D input arrays.");
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
|
||||
op_context.output->type);
|
||||
|
||||
if (!IsConstantTensor(op_context.perm)) {
|
||||
SetTensorToDynamic(op_context.output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
return ResizeOutputTensor(context, &op_context);
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TransposeContext op_context(context, node);
|
||||
|
||||
// Resize the output tensor if the output tensor is dynamic.
|
||||
if (IsDynamicTensor(op_context.output)) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
|
||||
}
|
||||
|
||||
const int* perm_data = GetTensorData<int32_t>(op_context.perm);
|
||||
const int size = op_context.perm->dims->data[0];
|
||||
TransposeParams params;
|
||||
params.perm_count = size;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
params.perm[i] = perm_data[i];
|
||||
}
|
||||
|
||||
#define TF_LITE_TRANSPOSE(type, scalar) \
|
||||
type::Transpose(params, GetTensorShape(op_context.input), \
|
||||
GetTensorData<scalar>(op_context.input), \
|
||||
GetTensorShape(op_context.output), \
|
||||
GetTensorData<scalar>(op_context.output))
|
||||
|
||||
// Transpose kernel only does rearranging values not numeric evaluations on
|
||||
// each cell. It's safe to implement per size of scalar type and this trick
|
||||
// keeps the total code size in a reasonable range.
|
||||
switch (op_context.input->type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteInt32:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int32_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int32_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_TRANSPOSE(reference_ops, int16_t);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
TF_LITE_TRANSPOSE(reference_ops, int64_t);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
if (sizeof(bool) == 1) {
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int8_t);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, bool);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s is currently not supported by Transpose.",
|
||||
TfLiteTypeGetName(op_context.input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
#undef TF_LITE_TRANSPOSE
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace transpose
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSE_REF() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
|
||||
transpose::Eval<transpose::kReference>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
|
||||
transpose::Eval<transpose::kGenericOptimized>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_TRANSPOSE() {
|
||||
return Register_TRANSPOSE_GENERIC_OPTIMIZED();
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user