Fix shape populating.

PiperOrigin-RevId: 248082580
This commit is contained in:
Renjie Liu 2019-05-14 00:01:33 -07:00 committed by TensorFlower Gardener
parent 27cfc61581
commit 78306dd2be
2 changed files with 8 additions and 4 deletions

View File

@ -63,12 +63,14 @@ void CreateDataAndRunAveragePool(bool padding_same) {
const int batch = UniformRandomInt(1, 2); const int batch = UniformRandomInt(1, 2);
const int input_depth = UniformRandomInt(1, 700); const int input_depth = UniformRandomInt(1, 700);
const int output_depth = input_depth; const int output_depth = input_depth;
const int input_width = UniformRandomInt(1, 30); const int input_width_offset = UniformRandomInt(1, 30);
const int input_height = UniformRandomInt(1, 30); const int input_height_offset = UniformRandomInt(1, 30);
const int stride_width = UniformRandomInt(1, 10); const int stride_width = UniformRandomInt(1, 10);
const int stride_height = UniformRandomInt(1, 10); const int stride_height = UniformRandomInt(1, 10);
const int filter_width = UniformRandomInt(1, 10); const int filter_width = UniformRandomInt(1, 10);
const int filter_height = UniformRandomInt(1, 10); const int filter_height = UniformRandomInt(1, 10);
const int input_width = input_width_offset + filter_width;
const int input_height = input_height_offset + filter_height;
const int output_width = const int output_width =
padding_same ? (input_width + stride_width - 1) / stride_width padding_same ? (input_width + stride_width - 1) / stride_width
: (input_width - filter_width + stride_width) / stride_width; : (input_width - filter_width + stride_width) / stride_width;

View File

@ -59,12 +59,14 @@ void CreateDataAndRunMaxPool(bool padding_same) {
const int batch = UniformRandomInt(1, 2); const int batch = UniformRandomInt(1, 2);
const int input_depth = UniformRandomInt(1, 700); const int input_depth = UniformRandomInt(1, 700);
const int output_depth = input_depth; const int output_depth = input_depth;
const int input_width = UniformRandomInt(1, 30); const int input_width_offset = UniformRandomInt(1, 30);
const int input_height = UniformRandomInt(1, 30); const int input_height_offset = UniformRandomInt(1, 30);
const int stride_width = UniformRandomInt(1, 10); const int stride_width = UniformRandomInt(1, 10);
const int stride_height = UniformRandomInt(1, 10); const int stride_height = UniformRandomInt(1, 10);
const int filter_width = UniformRandomInt(1, 10); const int filter_width = UniformRandomInt(1, 10);
const int filter_height = UniformRandomInt(1, 10); const int filter_height = UniformRandomInt(1, 10);
const int input_width = input_width_offset + filter_width;
const int input_height = input_height_offset + filter_height;
const int output_width = const int output_width =
padding_same ? (input_width + stride_width - 1) / stride_width padding_same ? (input_width + stride_width - 1) / stride_width
: (input_width - filter_width + stride_width) / stride_width; : (input_width - filter_width + stride_width) / stride_width;