diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index 1c43e77e7c2..1a9cf4c6406 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -20,7 +20,6 @@ limitations under the License.
 #include <math.h>
 #include <algorithm>
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -29,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/image_resizer_state.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 namespace {
@@ -235,6 +235,7 @@ inline void interpolate_with_caching(
 
   const T* input_b_ptr = input_data.data();
   float* output_y_ptr = output_data.data();
+  std::vector<float> cached_value(num_channels == 3 ? 0 : 4 * num_channels, 0);
 
   for (int64 b = 0; b < resizer_state.batch_size;
        ++b, input_b_ptr += in_batch_width) {
@@ -248,6 +249,7 @@ inline void interpolate_with_caching(
       const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
       const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
       const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;
+
       if (num_channels == 3) {
         // Manually unroll case of 3 channels.
         float cached_value_0[4] = {0};
@@ -330,48 +332,61 @@ inline void interpolate_with_caching(
                       x_wai.weight_2, x_wai.weight_3);
         }
       } else {
-        for (int64 c = 0; c < num_channels; ++c) {
-          float cached_value[4] = {0};
-          for (int64 x = 0; x < resizer_state.out_width; ++x) {
-            const WeightsAndIndices& x_wai = x_wais[x];
-            // Shift values in cached_value to fill first 'advance' values.
-            switch (x_wai.advance) {
-              case 3:
-                cached_value[0] = cached_value[1];
-                cached_value[1] = cached_value[2];
-                cached_value[2] = cached_value[3];
-                break;
-              case 2:
-                cached_value[0] = cached_value[2];
-                cached_value[1] = cached_value[3];
-                break;
-              case 1: {
-                cached_value[0] = cached_value[3];
-                break;
+        for (int64 x = 0; x < resizer_state.out_width; ++x) {
+          const WeightsAndIndices& x_wai = x_wais[x];
+          // Shift values in cached_value to fill first 'advance' values.
+          switch (x_wai.advance) {
+            case 3:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 0] = cached_value[4 * c + 1];
+                cached_value[4 * c + 1] = cached_value[4 * c + 2];
+                cached_value[4 * c + 2] = cached_value[4 * c + 3];
               }
+              break;
+            case 2:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 0] = cached_value[4 * c + 2];
+                cached_value[4 * c + 1] = cached_value[4 * c + 3];
+              }
+              break;
+            case 1: {
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 0] = cached_value[4 * c + 3];
+              }
+              break;
             }
-            // Set the remaining '4-advance' values by computing.
-            switch (x_wai.advance) {
-              case 0:
-                cached_value[0] = ComputeYInterpolation(
+          }
+          // Set the remaining '4-advance' values by computing.
+          switch (x_wai.advance) {
+            case 0:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 0] = ComputeYInterpolation(
                     0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-                TF_FALLTHROUGH_INTENDED;
-              case 1:
-                cached_value[1] = ComputeYInterpolation(
+              }
+              TF_FALLTHROUGH_INTENDED;
+            case 1:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 1] = ComputeYInterpolation(
                     1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-                TF_FALLTHROUGH_INTENDED;
-              case 2:
-                cached_value[2] = ComputeYInterpolation(
+              }
+              TF_FALLTHROUGH_INTENDED;
+            case 2:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 2] = ComputeYInterpolation(
                     2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-                TF_FALLTHROUGH_INTENDED;
-              case 3:
-                cached_value[3] = ComputeYInterpolation(
+              }
+              TF_FALLTHROUGH_INTENDED;
+            case 3:
+              for (int64 c = 0; c < num_channels; ++c) {
+                cached_value[4 * c + 3] = ComputeYInterpolation(
                     3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
-                break;
-            }
+              }
+              break;
+          }
+          for (int64 c = 0; c < num_channels; ++c) {
             output_y_ptr[x * num_channels + c] =
-                Compute(cached_value, x_wai.weight_0, x_wai.weight_1,
+                Compute(&cached_value[4 * c], x_wai.weight_0, x_wai.weight_1,
                         x_wai.weight_2, x_wai.weight_3);
           }
         }
       }
diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc
index ae14d2804e2..9e10fec4232 100644
--- a/tensorflow/core/kernels/resize_bicubic_op_test.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc
@@ -251,14 +251,15 @@ TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) {
   RunManyRandomTests(4);
 }
 
-static Graph* ResizeBicubic(int batch_size, int size, int channels) {
+static Graph* ResizeBicubic(int batch_size, int size, int channels,
+                            float scale_y = 0.3, float scale_x = 0.7) {
   Graph* g = new Graph(OpRegistry::Global());
   Tensor input(DT_FLOAT, TensorShape({batch_size, size, size, channels}));
   input.flat<float>().setRandom();
   Tensor shape(DT_INT32, TensorShape({2}));
   auto shape_t = shape.flat<int32>();
-  shape_t(0) = 0.3 * size;
-  shape_t(1) = 0.7 * size;
+  shape_t(0) = scale_y * size;
+  shape_t(1) = scale_x * size;
   test::graph::Binary(g, "ResizeBicubic", test::graph::Constant(g, input),
                       test::graph::Constant(g, shape));
   return g;
@@ -285,4 +286,17 @@ BM_ResizeBicubicDev(32, 128, 3);
 BM_ResizeBicubicDev(32, 512, 3);
 BM_ResizeBicubicDev(32, 1024, 3);
 
+#define BM_ResizeBicubicExpand(BATCH, SIZE, CHANNELS)                         \
+  static void BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS(         \
+      int iters) {                                                            \
+    testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * SIZE * SIZE * \
+                            CHANNELS * 8 * 8);                                \
+    test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS, 8, 8))        \
+        .Run(iters);                                                          \
+  }                                                                           \
+  BENCHMARK(BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS);
+
+BM_ResizeBicubicExpand(12, 48, 1);
+BM_ResizeBicubicExpand(12, 48, 3);
+BM_ResizeBicubicExpand(12, 48, 40);
+
 }  // end namespace tensorflow
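
A note on what the kernel change above is doing: the old code kept a 4-sample cache on the stack and ran the whole sliding-window loop once per channel, so every channel re-read x_wais and re-executed both switches. The new code hoists the x loop, stores all channels' caches in one flat std::vector<float> laid out as cached_value[4 * c + i], and loops over channels only inside each shift/recompute step. Below is a minimal standalone sketch of that flat sliding cache; the names (SampleColumn, cached, first) are hypothetical, the column mapping and advance computation are simplified stand-ins for x_wai, and it is not the TensorFlow kernel itself.

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for ComputeYInterpolation(): the vertically interpolated sample
// for input column `col` of channel `c`.
static float SampleColumn(int col, int c) {
  return static_cast<float>(col) + 0.01f * static_cast<float>(c);
}

int main() {
  const int num_channels = 2;
  const int out_width = 5;
  // One flat cache for all channels: entry [4 * c + i] holds the i-th of the
  // four column samples channel c needs for the current output column.
  std::vector<float> cached(4 * num_channels, 0.0f);

  int prev_first = -4;  // leftmost input column currently cached
  for (int x = 0; x < out_width; ++x) {
    const int first = x;  // leftmost of the 4 columns needed (toy mapping)
    // How many cached samples survive the move to this output column.
    const int advance = std::max(0, 4 - (first - prev_first));

    // Shift surviving samples to the front, all channels in one pass --
    // this is the loop interchange the patch performs.
    for (int c = 0; c < num_channels; ++c)
      for (int i = 0; i < advance; ++i)
        cached[4 * c + i] = cached[4 * c + i + (4 - advance)];

    // Recompute only the missing trailing 4 - advance samples.
    for (int c = 0; c < num_channels; ++c)
      for (int i = advance; i < 4; ++i)
        cached[4 * c + i] = SampleColumn(first + i, c);

    // The real kernel would now blend cached[4 * c + 0..3] with the x
    // weights; here we just print each channel's window.
    for (int c = 0; c < num_channels; ++c)
      std::printf("x=%d c=%d: %g %g %g %g\n", x, c, cached[4 * c + 0],
                  cached[4 * c + 1], cached[4 * c + 2], cached[4 * c + 3]);
    prev_first = first;
  }
  return 0;
}

Keeping the per-channel windows contiguous in one vector is also what lets the final blend take &cached_value[4 * c] directly, so Compute() is unchanged from the per-channel version.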