diff --git a/tensorflow/lite/tools/evaluation/proto/preprocessing_steps.proto b/tensorflow/lite/tools/evaluation/proto/preprocessing_steps.proto index 0c9710639c1..05b0d53c7cd 100644 --- a/tensorflow/lite/tools/evaluation/proto/preprocessing_steps.proto +++ b/tensorflow/lite/tools/evaluation/proto/preprocessing_steps.proto @@ -56,9 +56,9 @@ message CroppingParams { float cropping_fraction = 1 [default = 0.875]; // The target size after cropping. ImageSize target_size = 2; - // Crops to a square image. - bool square_cropping = 3; } + // Crops to a square image. + optional bool square_cropping = 3; } // Defines parameters for bilinear central-resizing. diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc index 4d4f83c69f5..c9f8f832441 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc @@ -67,8 +67,7 @@ TfLiteStatus ImageClassificationStage::Init() { // ImagePreprocessingStage tflite::evaluation::ImagePreprocessingConfigBuilder builder( "image_preprocessing", input_type); - builder.AddSquareCroppingStep(); - builder.AddCroppingStep(kCroppingFraction); + builder.AddCroppingStep(kCroppingFraction, true /*square*/); builder.AddResizingStep(input_shape->data[2], input_shape->data[1], false); builder.AddDefaultNormalizationStep(); preprocessing_stage_.reset(new ImagePreprocessingStage(builder.build())); diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc index 3f1a922ac79..dd434a1c882 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc @@ -117,9 +117,9 @@ inline void Crop(ImageData* image_data, const CroppingParams& crop_params) { } else if (crop_params.has_target_size()) { crop_height = crop_params.target_size().height(); crop_width = crop_params.target_size().width(); - } else { - // Square cropping. - crop_height = std::min(input_height, input_width); + } + if (crop_params.has_cropping_fraction() && crop_params.square_cropping()) { + crop_height = std::min(crop_height, crop_width); crop_width = crop_height; } int start_w = static_cast(round((input_width - crop_width) / 2.0)); diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h index 959248dab34..5056e5246c4 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h @@ -73,9 +73,11 @@ class ImagePreprocessingConfigBuilder { } // Adds a cropping step with cropping fraction. - void AddCroppingStep(float cropping_fraction) { + void AddCroppingStep(float cropping_fraction, + bool use_square_cropping = false) { ImagePreprocessingStepParams params; params.mutable_cropping_params()->set_cropping_fraction(cropping_fraction); + params.mutable_cropping_params()->set_square_cropping(use_square_cropping); config_.mutable_specification() ->mutable_image_preprocessing_params() ->mutable_steps() @@ -83,20 +85,12 @@ class ImagePreprocessingConfigBuilder { } // Adds a cropping step with target size. - void AddCroppingStep(uint32_t width, uint32_t height) { + void AddCroppingStep(uint32_t width, uint32_t height, + bool use_square_cropping = false) { ImagePreprocessingStepParams params; params.mutable_cropping_params()->mutable_target_size()->set_height(height); params.mutable_cropping_params()->mutable_target_size()->set_width(width); - config_.mutable_specification() - ->mutable_image_preprocessing_params() - ->mutable_steps() - ->Add(std::move(params)); - } - - // Adds a square cropping step. - void AddSquareCroppingStep() { - ImagePreprocessingStepParams params; - params.mutable_cropping_params()->set_square_cropping(true); + params.mutable_cropping_params()->set_square_cropping(use_square_cropping); config_.mutable_specification() ->mutable_image_preprocessing_params() ->mutable_steps()