Simplify preprocessing steps for image classification

PiperOrigin-RevId: 290878807
Change-Id: I0ea3c7629659f8dac49ee5a4153b9b3e458946d9
This commit is contained in:
Thai Nguyen 2020-01-21 20:19:03 -08:00 committed by TensorFlower Gardener
parent 1ce470efe7
commit 81c88fb3ab
4 changed files with 12 additions and 19 deletions

View File

@ -56,9 +56,9 @@ message CroppingParams {
float cropping_fraction = 1 [default = 0.875];
// The target size after cropping.
ImageSize target_size = 2;
// Crops to a square image.
bool square_cropping = 3;
}
// Crops to a square image.
optional bool square_cropping = 3;
}
// Defines parameters for bilinear central-resizing.

View File

@ -67,8 +67,7 @@ TfLiteStatus ImageClassificationStage::Init() {
// ImagePreprocessingStage
tflite::evaluation::ImagePreprocessingConfigBuilder builder(
"image_preprocessing", input_type);
builder.AddSquareCroppingStep();
builder.AddCroppingStep(kCroppingFraction);
builder.AddCroppingStep(kCroppingFraction, true /*square*/);
builder.AddResizingStep(input_shape->data[2], input_shape->data[1], false);
builder.AddDefaultNormalizationStep();
preprocessing_stage_.reset(new ImagePreprocessingStage(builder.build()));

View File

@ -117,9 +117,9 @@ inline void Crop(ImageData* image_data, const CroppingParams& crop_params) {
} else if (crop_params.has_target_size()) {
crop_height = crop_params.target_size().height();
crop_width = crop_params.target_size().width();
} else {
// Square cropping.
crop_height = std::min(input_height, input_width);
}
if (crop_params.has_cropping_fraction() && crop_params.square_cropping()) {
crop_height = std::min(crop_height, crop_width);
crop_width = crop_height;
}
int start_w = static_cast<int>(round((input_width - crop_width) / 2.0));

View File

@ -73,9 +73,11 @@ class ImagePreprocessingConfigBuilder {
}
// Adds a cropping step with cropping fraction.
void AddCroppingStep(float cropping_fraction) {
void AddCroppingStep(float cropping_fraction,
bool use_square_cropping = false) {
ImagePreprocessingStepParams params;
params.mutable_cropping_params()->set_cropping_fraction(cropping_fraction);
params.mutable_cropping_params()->set_square_cropping(use_square_cropping);
config_.mutable_specification()
->mutable_image_preprocessing_params()
->mutable_steps()
@ -83,20 +85,12 @@ class ImagePreprocessingConfigBuilder {
}
// Adds a cropping step with target size.
void AddCroppingStep(uint32_t width, uint32_t height) {
void AddCroppingStep(uint32_t width, uint32_t height,
bool use_square_cropping = false) {
ImagePreprocessingStepParams params;
params.mutable_cropping_params()->mutable_target_size()->set_height(height);
params.mutable_cropping_params()->mutable_target_size()->set_width(width);
config_.mutable_specification()
->mutable_image_preprocessing_params()
->mutable_steps()
->Add(std::move(params));
}
// Adds a square cropping step.
void AddSquareCroppingStep() {
ImagePreprocessingStepParams params;
params.mutable_cropping_params()->set_square_cropping(true);
params.mutable_cropping_params()->set_square_cropping(use_square_cropping);
config_.mutable_specification()
->mutable_image_preprocessing_params()
->mutable_steps()