Misc op fixes and renaming

Rename TRANSPOSE_CONV_2D to TRANSPOSE_CONV. Fix invariant checks in space_to_batch_nd and batch_to_space_nd.
Nat Jeffries 2021-02-10 11:24:10 -08:00
parent 452fc5dead
commit 7082bbb6e7
4 changed files with 31 additions and 31 deletions
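The TRANSPOSE_CONV_2D to TRANSPOSE_CONV rename below only changes the registration symbol, not the kernel's behavior, so any code that referenced the old symbol has to be updated. A minimal sketch of an affected call site (hypothetical variable name, assuming code that called the registration function directly):

    // Before this commit:
    //   TfLiteRegistration registration = tflite::Register_TRANSPOSE_CONV_2D();
    // After this commit, the same call site becomes:
    TfLiteRegistration registration = tflite::Register_TRANSPOSE_CONV();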

View File

@@ -29,8 +29,12 @@ constexpr int kBlockShapeTensor = 1;
constexpr int kCropsTensor = 2;
constexpr int kOutputTensor = 0;
-constexpr int kInputDims = 4;
-constexpr int kOutputDims = 4;
+// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
+// In case of 3D input, it will be extended to 4D NHWC by adding W=1.
+// The 4D array needs to have exactly 2 spatial dimensions.
+// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
+const int kInputOutputMinDimensionNum = 3;
+const int kInputOutputMaxDimensionNum = 4;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
@@ -40,15 +44,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
-// Only 4D input and output tensors are supported for this op on TFLM.
-TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
-TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
-TF_LITE_ENSURE_EQ(context, input->type, output->type);
+TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
+TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
// Input and output must have the same flat size since TFLM does not support
// tensor resizing.
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
GetTensorShape(output).FlatSize());
return kTfLiteOk;
}
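For context on the flat-size invariant both kernels keep: BatchToSpaceND only redistributes elements between the batch and spatial dimensions, so input and output always hold the same number of elements and no tensor resizing is needed. A worked example with illustrative shapes (not taken from the commit):

    // Illustrative only: input [4, 2, 2, 1] with block_shape [2, 2] and zero
    // crops maps to output [1, 4, 4, 1]. Both tensors hold 16 elements, so
    // GetTensorShape(input).FlatSize() == GetTensorShape(output).FlatSize().
    static_assert(4 * 2 * 2 * 1 == 1 * 4 * 4 * 1, "flat sizes must match");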

View File

@@ -41,7 +41,7 @@ TfLiteRegistration Register_SHAPE();
TfLiteRegistration Register_SOFTMAX();
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
TfLiteRegistration Register_SVDF();
-TfLiteRegistration Register_TRANSPOSE_CONV_2D();
+TfLiteRegistration Register_TRANSPOSE_CONV();
TfLiteRegistration Register_ZEROS_LIKE();
namespace ops {

View File

@@ -30,8 +30,12 @@ constexpr int kBlockShapeTensor = 1;
constexpr int kCropsTensor = 2;
constexpr int kOutputTensor = 0;
-constexpr int kInputDims = 4;
-constexpr int kOutputDims = 4;
+// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
+// In case of 3D input, it will be extended to 4D NHWC by adding W=1.
+// The 4D array needs to have exactly 2 spatial dimensions.
+// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
+const int kInputOutputMinDimensionNum = 3;
+const int kInputOutputMaxDimensionNum = 4;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
@@ -46,19 +50,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
SpaceToBatchParams* params =
static_cast<SpaceToBatchParams*>(node->user_data);
params->output_offset = output->params.zero_point;
-// Only 4D input and output tensors are supported for this op on TFLM.
-TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
-TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
-TF_LITE_ENSURE_EQ(context, input->type, output->type);
+TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
+TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
+TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
// Input and output must have the same flat size since TFLM does not support
// tensor resizing.
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
GetTensorShape(output).FlatSize());
return kTfLiteOk;
}
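The same reasoning applies here in the opposite direction: SpaceToBatchND with zero padding moves elements from the spatial dimensions into the batch dimension without changing their count, which is what lets the kernel keep the FlatSize() equality under TFLM's no-resizing constraint. A sketch with illustrative shapes (not from the commit):

    // Illustrative only: input [1, 4, 4, 1] with block_shape [2, 2] and zero
    // paddings maps to output [4, 2, 2, 1]; the element count is unchanged.
    static_assert(1 * 4 * 4 * 1 == 4 * 2 * 2 * 1, "flat sizes must match");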

View File

@@ -28,9 +28,11 @@ limitations under the License.
namespace tflite {
namespace {
-constexpr int kInputTensor = 0;
+// OutputShapeTensor is tensor 0, but it is unused in TFLM because we do not
+// support dynamic tensor resizing.
constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
+constexpr int kInputTensor = 2;
+constexpr int kBiasTensor = 3;
constexpr int kOutputTensor = 0;
// Conv is quantized along dimension 0:
@@ -65,9 +67,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
int height, int filter_width, int filter_height,
int out_width, int out_height,
const TfLiteType data_type, OpData* data) {
-bool has_bias = node->inputs->size == 3;
+bool has_bias = node->inputs->size == 4;
// Check number of inputs/outputs
-TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+TF_LITE_ENSURE(context, has_bias || node->inputs->size == 3);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Matching GetWindowedOutputSize in TensorFlow.
@@ -198,7 +200,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFilterTensor);
const TfLiteEvalTensor* bias =
-(NumInputs(node) == 3)
+(NumInputs(node) == 4)
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
@@ -251,7 +253,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
-TfLiteRegistration Register_TRANSPOSE_CONV_2D() {
+TfLiteRegistration Register_TRANSPOSE_CONV() {
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/Prepare,
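The index changes above follow from TRANSPOSE_CONV's input ordering, in which the requested output shape is the first input. A hedged restatement of the layout implied by the updated constants (the enumerator names below are illustrative, not from the source):

    // Illustrative restatement of the tensor indices the kernel now expects:
    enum TransposeConvInputIndex {
      kOutputShapeInput = 0,  // present in the model, ignored by TFLM (no dynamic resizing)
      kFilterInput = 1,       // filter / weights tensor
      kActivationInput = 2,   // activation input tensor
      kBiasInput = 3,         // optional; present when node->inputs->size == 4
    };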