Misc op fixes and renaming
Rename TRANSPOSE_CONV_2D to TRANSPOSE_CONV. Fix invariant checks in space_to_batch_nd and batch_to_space_nd.
This commit is contained in:
parent
452fc5dead
commit
7082bbb6e7
@ -29,8 +29,12 @@ constexpr int kBlockShapeTensor = 1;
|
|||||||
constexpr int kCropsTensor = 2;
|
constexpr int kCropsTensor = 2;
|
||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
constexpr int kInputDims = 4;
|
// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
|
||||||
constexpr int kOutputDims = 4;
|
// In case of 3D input, it will be extended to 3D NHWC by adding W=1.
|
||||||
|
// The 4D array need to have exactly 2 spatial dimensions.
|
||||||
|
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
|
||||||
|
const int kInputOutputMinDimensionNum = 3;
|
||||||
|
const int kInputOutputMaxDimensionNum = 4;
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||||
@ -40,15 +44,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||||
|
|
||||||
// Only 4D input and output tensors are supported for this op on TFLM.
|
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
|
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
|
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||||
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||||
|
|
||||||
// Input and output must have the same flat size since TFLM does not support
|
|
||||||
// tensor resizing.
|
|
||||||
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
|
|
||||||
GetTensorShape(output).FlatSize());
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ TfLiteRegistration Register_SHAPE();
|
|||||||
TfLiteRegistration Register_SOFTMAX();
|
TfLiteRegistration Register_SOFTMAX();
|
||||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||||
TfLiteRegistration Register_SVDF();
|
TfLiteRegistration Register_SVDF();
|
||||||
TfLiteRegistration Register_TRANSPOSE_CONV_2D();
|
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||||
TfLiteRegistration Register_ZEROS_LIKE();
|
TfLiteRegistration Register_ZEROS_LIKE();
|
||||||
|
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -30,8 +30,12 @@ constexpr int kBlockShapeTensor = 1;
|
|||||||
constexpr int kCropsTensor = 2;
|
constexpr int kCropsTensor = 2;
|
||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
constexpr int kInputDims = 4;
|
// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
|
||||||
constexpr int kOutputDims = 4;
|
// In case of 3D input, it will be extended to 3D NHWC by adding W=1.
|
||||||
|
// The 4D array need to have exactly 2 spatial dimensions.
|
||||||
|
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
|
||||||
|
const int kInputOutputMinDimensionNum = 3;
|
||||||
|
const int kInputOutputMaxDimensionNum = 4;
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||||
@ -46,19 +50,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||||
|
|
||||||
SpaceToBatchParams* params =
|
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||||
static_cast<SpaceToBatchParams*>(node->user_data);
|
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
|
||||||
params->output_offset = output->params.zero_point;
|
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
|
||||||
|
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||||
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||||
|
|
||||||
// Only 4D input and output tensors are supported for this op on TFLM.
|
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
|
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
|
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
|
||||||
|
|
||||||
// Input and output must have the same flat size since TFLM does not support
|
|
||||||
// tensor resizing.
|
|
||||||
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
|
|
||||||
GetTensorShape(output).FlatSize());
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
// OutputShapeTensor is tensor 0, but it is unused in TFLM because we do not
|
||||||
|
// support dynamic tensor resizing.
|
||||||
constexpr int kFilterTensor = 1;
|
constexpr int kFilterTensor = 1;
|
||||||
constexpr int kBiasTensor = 2;
|
constexpr int kInputTensor = 2;
|
||||||
|
constexpr int kBiasTensor = 3;
|
||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
// Conv is quantized along dimension 0:
|
// Conv is quantized along dimension 0:
|
||||||
@ -65,9 +67,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
|||||||
int height, int filter_width, int filter_height,
|
int height, int filter_width, int filter_height,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
const TfLiteType data_type, OpData* data) {
|
const TfLiteType data_type, OpData* data) {
|
||||||
bool has_bias = node->inputs->size == 3;
|
bool has_bias = node->inputs->size == 4;
|
||||||
// Check number of inputs/outputs
|
// Check number of inputs/outputs
|
||||||
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
|
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 3);
|
||||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||||
|
|
||||||
// Matching GetWindowedOutputSize in TensorFlow.
|
// Matching GetWindowedOutputSize in TensorFlow.
|
||||||
@ -198,7 +200,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const TfLiteEvalTensor* filter =
|
const TfLiteEvalTensor* filter =
|
||||||
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
||||||
const TfLiteEvalTensor* bias =
|
const TfLiteEvalTensor* bias =
|
||||||
(NumInputs(node) == 3)
|
(NumInputs(node) == 4)
|
||||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||||
: nullptr;
|
: nullptr;
|
||||||
TfLiteEvalTensor* output =
|
TfLiteEvalTensor* output =
|
||||||
@ -251,7 +253,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TfLiteRegistration Register_TRANSPOSE_CONV_2D() {
|
TfLiteRegistration Register_TRANSPOSE_CONV() {
|
||||||
return {/*init=*/Init,
|
return {/*init=*/Init,
|
||||||
/*free=*/nullptr,
|
/*free=*/nullptr,
|
||||||
/*prepare=*/Prepare,
|
/*prepare=*/Prepare,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user