Misc op fixes and renaming
Rename TRANSPOSE_CONV_2D to TRANSPOSE_CONV. Fix invariant checks in space_to_batch_nd and batch_to_space_nd.
This commit is contained in:
parent
452fc5dead
commit
7082bbb6e7
@ -29,8 +29,12 @@ constexpr int kBlockShapeTensor = 1;
|
||||
constexpr int kCropsTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
constexpr int kInputDims = 4;
|
||||
constexpr int kOutputDims = 4;
|
||||
// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
|
||||
// In case of 3D input, it will be extended to 3D NHWC by adding W=1.
|
||||
// The 4D array need to have exactly 2 spatial dimensions.
|
||||
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
|
||||
const int kInputOutputMinDimensionNum = 3;
|
||||
const int kInputOutputMaxDimensionNum = 4;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
@ -40,15 +44,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||
|
||||
// Only 4D input and output tensors are supported for this op on TFLM.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
// Input and output must have the same flat size since TFLM does not support
|
||||
// tensor resizing.
|
||||
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
|
||||
GetTensorShape(output).FlatSize());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ TfLiteRegistration Register_SHAPE();
|
||||
TfLiteRegistration Register_SOFTMAX();
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||
TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV_2D();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV();
|
||||
TfLiteRegistration Register_ZEROS_LIKE();
|
||||
|
||||
namespace ops {
|
||||
|
@ -30,8 +30,12 @@ constexpr int kBlockShapeTensor = 1;
|
||||
constexpr int kCropsTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
constexpr int kInputDims = 4;
|
||||
constexpr int kOutputDims = 4;
|
||||
// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
|
||||
// In case of 3D input, it will be extended to 3D NHWC by adding W=1.
|
||||
// The 4D array need to have exactly 2 spatial dimensions.
|
||||
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
|
||||
const int kInputOutputMinDimensionNum = 3;
|
||||
const int kInputOutputMaxDimensionNum = 4;
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
@ -46,19 +50,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);
|
||||
|
||||
SpaceToBatchParams* params =
|
||||
static_cast<SpaceToBatchParams*>(node->user_data);
|
||||
params->output_offset = output->params.zero_point;
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
// Only 4D input and output tensors are supported for this op on TFLM.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), kInputDims);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), kOutputDims);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
|
||||
// Input and output must have the same flat size since TFLM does not support
|
||||
// tensor resizing.
|
||||
TF_LITE_ENSURE_EQ(context, GetTensorShape(input).FlatSize(),
|
||||
GetTensorShape(output).FlatSize());
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
// OutputShapeTensor is tensor 0, but it is unused in TFLM because we do not
|
||||
// support dynamic tensor resizing.
|
||||
constexpr int kFilterTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kInputTensor = 2;
|
||||
constexpr int kBiasTensor = 3;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
// Conv is quantized along dimension 0:
|
||||
@ -65,9 +67,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
|
||||
int height, int filter_width, int filter_height,
|
||||
int out_width, int out_height,
|
||||
const TfLiteType data_type, OpData* data) {
|
||||
bool has_bias = node->inputs->size == 3;
|
||||
bool has_bias = node->inputs->size == 4;
|
||||
// Check number of inputs/outputs
|
||||
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
|
||||
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 3);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
|
||||
// Matching GetWindowedOutputSize in TensorFlow.
|
||||
@ -198,7 +200,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kFilterTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
(NumInputs(node) == 4)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
TfLiteEvalTensor* output =
|
||||
@ -251,7 +253,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV_2D() {
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV() {
|
||||
return {/*init=*/Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
|
Loading…
x
Reference in New Issue
Block a user