diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 4e50a041903..caf73420ba9 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -42,6 +42,10 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context, const Tensor& boxes, const Tensor& box_ind, int* num_boxes) { + if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { + *num_boxes = 0; + return; + } // The shape of 'boxes' is [num_boxes, 4]. OP_REQUIRES(context, boxes.dims() == 2, errors::InvalidArgument("boxes must be 2-D", @@ -132,9 +136,13 @@ class CropAndResizeOp : public OpKernel { CheckValidBoxInd(context, box_ind_data, batch); - functor::CropAndResize()(context->eigen_device(), - image_data, boxes_data, box_ind_data, - extrapolation_value_, crops_data); + bool status = functor::CropAndResize()( + context->eigen_device(), image_data, boxes_data, box_ind_data, + extrapolation_value_, crops_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } } private: @@ -145,11 +153,12 @@ class CropAndResizeOp : public OpKernel { namespace functor { template struct CropAndResize { - void operator()(const CPUDevice& d, typename TTypes::ConstTensor image, + bool operator()(const CPUDevice& d, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, float extrapolation_value, typename TTypes::Tensor crops) { + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -163,7 +172,11 @@ struct CropAndResize { const float x1 = boxes(b, 1); const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) @@ -217,6 +230,7 @@ struct CropAndResize { } } } + return true; } }; } // namespace functor @@ -235,6 +249,7 @@ class CropAndResizeGradImageOp : public OpKernel { void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); + OP_REQUIRES(context, grads.dims() == 4, errors::InvalidArgument("grads image must be 4-D", grads.shape().DebugString())); @@ -294,9 +309,13 @@ class CropAndResizeGradImageOp : public OpKernel { CheckValidBoxInd(context, box_ind_data, batch); - functor::CropAndResizeBackpropImage()( + bool status = functor::CropAndResizeBackpropImage()( context->eigen_device(), grads_data, boxes_data, box_ind_data, output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); + } } }; @@ -304,11 +323,12 @@ class CropAndResizeGradImageOp : public OpKernel { namespace functor { template struct CropAndResizeBackpropImage { - void operator()(const CPUDevice& d, + bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_image) { + const int batch = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -324,7 +344,11 @@ struct CropAndResizeBackpropImage { const float x1 = boxes(b, 1); const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) @@ -370,6 +394,7 @@ struct CropAndResizeBackpropImage { } } } + return true; } }; } // namespace functor @@ -388,6 +413,7 @@ class CropAndResizeGradBoxesOp : public OpKernel { void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); + OP_REQUIRES(context, grads.dims() == 4, errors::InvalidArgument("grads image must be 4-D", grads.shape().DebugString())); @@ -441,9 +467,13 @@ class CropAndResizeGradBoxesOp : public OpKernel { CheckValidBoxInd(context, box_ind_data, batch); - functor::CropAndResizeBackpropBoxes()( + bool status = functor::CropAndResizeBackpropBoxes()( context->eigen_device(), grads_data, image_data, boxes_data, box_ind_data, output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); + } } }; @@ -451,12 +481,13 @@ class CropAndResizeGradBoxesOp : public OpKernel { namespace functor { template struct CropAndResizeBackpropBoxes { - void operator()(const CPUDevice& d, + bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_boxes) { + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -472,7 +503,11 @@ struct CropAndResizeBackpropBoxes { const float x1 = boxes(b, 1); const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_ratio = (crop_height > 1) @@ -547,6 +582,7 @@ struct CropAndResizeBackpropBoxes { } } } + return true; } }; } // namespace functor @@ -563,37 +599,25 @@ inline void CheckValidBoxInd( } } -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp); - -TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); - -#undef REGISTER_KERNEL - -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("image_size"), \ - CropAndResizeGradImageOp); - -TF_CALL_half(REGISTER_KERNEL); -TF_CALL_float(REGISTER_KERNEL); -TF_CALL_double(REGISTER_KERNEL); - -#undef REGISTER_KERNEL - -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("image_size"), \ + CropAndResizeGradImageOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); -TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); +TF_CALL_float(REGISTER_KERNEL); #undef REGISTER_KERNEL @@ -613,6 +637,10 @@ template <> inline void CheckValidBoxInd( OpKernelContext* context, typename TTypes::ConstTensor box_ind, int batch) { + const int num_boxes = box_ind.dimension(0); + if (num_boxes == 0) { + return; + } Tensor isvalid_tensor; OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, @@ -657,7 +685,7 @@ inline void CheckValidBoxInd( .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL); +TF_CALL_float(REGISTER_KERNEL); #undef REGISTER_KERNEL diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 9278893704e..22df1bdd56b 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -26,7 +26,7 @@ namespace functor { template struct CropAndResize { // We assume that the tensor sizes are correct. - void operator()(const Device& d, typename TTypes::ConstTensor image, + bool operator()(const Device& d, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, float extrapolation_value, @@ -36,7 +36,7 @@ struct CropAndResize { template struct CropAndResizeBackpropImage { // We assume that the tensor sizes are correct. - void operator()(const Device& d, typename TTypes::ConstTensor grads, + bool operator()(const Device& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_image); @@ -45,7 +45,7 @@ struct CropAndResizeBackpropImage { template struct CropAndResizeBackpropBoxes { // We assume that the tensor sizes are correct. - void operator()(const Device& d, typename TTypes::ConstTensor grads, + bool operator()(const Device& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index 3759a8cb4ce..75146b28e66 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -33,27 +33,30 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template -__global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr, - const float* boxes_ptr, - const int32* box_ind_ptr, int num_boxes, - int image_height, int image_width, - int crop_height, int crop_width, int depth, - float extrapolation_value, - float* crops_ptr) { +__global__ void CropAndResizeKernel( + const int32 nthreads, const T* image_ptr, const float* boxes_ptr, + const int32* box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float* crops_ptr) { CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { // out_idx = d + depth * (w + crop_width * (h + crop_height * b)) - const int d = out_idx % depth; - const int out_idx2 = out_idx / depth; - const int x = out_idx2 % crop_width; - const int out_idx3 = out_idx2 / crop_width; - const int y = out_idx3 % crop_height; - const int b = out_idx3 / crop_height; + int idx = out_idx; + const int d = idx % depth; + idx /= depth; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + const int b = idx / crop_height; const float y1 = boxes_ptr[b * 4]; const float x1 = boxes_ptr[b * 4 + 1]; const float y2 = boxes_ptr[b * 4 + 2]; const float x2 = boxes_ptr[b * 4 + 3]; + const int32 b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) @@ -66,7 +69,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr, : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { crops_ptr[out_idx] = extrapolation_value; - return; + continue; } const float in_x = (crop_width > 1) @@ -74,7 +77,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr, : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { crops_ptr[out_idx] = extrapolation_value; - return; + continue; } const int top_y_index = floorf(in_y); @@ -114,22 +117,28 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr, template __global__ void CropAndResizeBackpropImageKernel( const int32 nthreads, const float* grads_ptr, const float* boxes_ptr, - const int32* box_ind_ptr, int num_boxes, int image_height, int image_width, - int crop_height, int crop_width, int depth, T* grads_image_ptr) { + const int32* box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + T* grads_image_ptr) { CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { // out_idx = d + depth * (w + crop_width * (h + crop_height * b)) - const int d = out_idx % depth; - const int out_idx2 = out_idx / depth; - const int x = out_idx2 % crop_width; - const int out_idx3 = out_idx2 / crop_width; - const int y = out_idx3 % crop_height; - const int b = out_idx3 / crop_height; + int idx = out_idx; + const int d = idx % depth; + idx /= depth; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + const int b = idx / crop_height; const float y1 = boxes_ptr[b * 4]; const float x1 = boxes_ptr[b * 4 + 1]; const float y2 = boxes_ptr[b * 4 + 2]; const float x2 = boxes_ptr[b * 4 + 3]; + const int32 b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) @@ -141,14 +150,14 @@ __global__ void CropAndResizeBackpropImageKernel( ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { - return; + continue; } const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { - return; + continue; } const int top_y_index = floorf(in_y); @@ -192,23 +201,28 @@ __global__ void CropAndResizeBackpropImageKernel( template __global__ void CropAndResizeBackpropBoxesKernel( const int32 nthreads, const float* grads_ptr, const T* image_ptr, - const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, + const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch, int image_height, int image_width, int crop_height, int crop_width, int depth, float* grads_boxes_ptr) { CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { // out_idx = d + depth * (w + crop_width * (h + crop_height * b)) - const int d = out_idx % depth; - const int out_idx2 = out_idx / depth; - const int x = out_idx2 % crop_width; - const int out_idx3 = out_idx2 / crop_width; - const int y = out_idx3 % crop_height; - const int b = out_idx3 / crop_height; + int idx = out_idx; + const int d = idx % depth; + idx /= depth; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + const int b = idx / crop_height; const float y1 = boxes_ptr[b * 4]; const float x1 = boxes_ptr[b * 4 + 1]; const float y2 = boxes_ptr[b * 4 + 2]; const float x2 = boxes_ptr[b * 4 + 3]; + const int32 b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) { + continue; + } const float height_ratio = (crop_height > 1) @@ -226,14 +240,14 @@ __global__ void CropAndResizeBackpropBoxesKernel( ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { - return; + continue; } const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { - return; + continue; } const int top_y_index = floorf(in_y); @@ -306,11 +320,12 @@ namespace functor { template struct CropAndResize { - void operator()(const GPUDevice& d, typename TTypes::ConstTensor image, + bool operator()(const GPUDevice& d, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, float extrapolation_value, typename TTypes::Tensor crops) { + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -320,19 +335,22 @@ struct CropAndResize { const int depth = crops.dimension(3); const int total_count = num_boxes * crop_height * crop_width * depth; - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); - CropAndResizeKernel<<>>( - config.virtual_thread_count, image.data(), boxes.data(), box_ind.data(), - num_boxes, image_height, image_width, crop_height, crop_width, depth, - extrapolation_value, crops.data()); + if (total_count > 0) { + CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + CropAndResizeKernel<<>>( + config.virtual_thread_count, image.data(), boxes.data(), + box_ind.data(), num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, extrapolation_value, crops.data()); + } + return d.ok(); } }; template struct CropAndResizeBackpropImage { - void operator()(const GPUDevice& d, + bool operator()(const GPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, @@ -351,29 +369,35 @@ struct CropAndResizeBackpropImage { // Initialize grads_image with all zeros. total_count = batch * image_height * image_width * depth; - config = GetCudaLaunchConfig(total_count, d); - SetZero<<>>( - total_count, grads_image.data()); + if (total_count > 0) { + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + config.virtual_thread_count, grads_image.data()); + } // Accumulate. total_count = num_boxes * crop_height * crop_width * depth; - config = GetCudaLaunchConfig(total_count, d); - CropAndResizeBackpropImageKernel<<< - config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, grads.data(), boxes.data(), box_ind.data(), - num_boxes, image_height, image_width, crop_height, crop_width, depth, - grads_image.data()); + if (total_count > 0) { + config = GetCudaLaunchConfig(total_count, d); + CropAndResizeBackpropImageKernel<<< + config.block_count, config.thread_per_block, 0, d.stream()>>>( + config.virtual_thread_count, grads.data(), boxes.data(), + box_ind.data(), num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, grads_image.data()); + } + return d.ok(); } }; template struct CropAndResizeBackpropBoxes { - void operator()(const GPUDevice& d, + bool operator()(const GPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_boxes) { + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -387,18 +411,23 @@ struct CropAndResizeBackpropBoxes { // Initialize grads_boxes with all zeros. total_count = num_boxes * 4; - config = GetCudaLaunchConfig(total_count, d); - SetZero<<>>( - total_count, grads_boxes.data()); + if (total_count > 0) { + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + config.virtual_thread_count, grads_boxes.data()); + } // Accumulate. total_count = num_boxes * crop_height * crop_width * depth; - config = GetCudaLaunchConfig(total_count, d); - CropAndResizeBackpropBoxesKernel<<< - config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, grads.data(), image.data(), boxes.data(), - box_ind.data(), num_boxes, image_height, image_width, crop_height, - crop_width, depth, grads_boxes.data()); + if (total_count > 0) { + config = GetCudaLaunchConfig(total_count, d); + CropAndResizeBackpropBoxesKernel<<< + config.block_count, config.thread_per_block, 0, d.stream()>>>( + config.virtual_thread_count, grads.data(), image.data(), boxes.data(), + box_ind.data(), num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, grads_boxes.data()); + } + return d.ok(); } }; @@ -407,7 +436,7 @@ struct CropAndResizeBackpropBoxes { template struct CropAndResizeBackpropImage; \ template struct CropAndResizeBackpropBoxes; -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS); +TF_CALL_float(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index 38f3c1adb2a..68e077e44df 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -189,6 +189,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) { + MakeOp(0); + // Input: + // 1, 2 + // 3, 4 + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({0, 4}), {}); + AddInputFromArray(TensorShape({0}), {}); + AddInputFromArray(TensorShape({2}), {3, 3}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({0, 3, 3, 1})); + // clang-format off + test::FillValues(&expected, {}); + // clang-format on + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + TEST_F(CropAndResizeOpTest, TestInvalidInputShape) { MakeOp(0); AddInputFromArray(TensorShape({2, 2, 1}), {1, 2, 3, 4}); @@ -201,6 +219,19 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) { << s; } +TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { + MakeOp(0); + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray(TensorShape({2}), {0, 0}); + AddInputFromArray(TensorShape({2}), {4, 4}); + Status s = RunOpKernel(); + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("box_ind has incompatible shape")) + << s; +} + TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { MakeOp(0); AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index dab96194247..ccbae0f1572 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -253,54 +253,53 @@ class CropAndResizeOpTest(tf.test.TestCase): radius = 2 * delta low, high = -0.5, 1.5 # Also covers the case of extrapolation. - for image_height in range(1, 5): - for image_width in range(1, 3): - for crop_height in range(1, 3): - for crop_width in range(2, 4): - for depth in range(1, 3): - for num_boxes in range(1, 3): + image_height = 4 + for image_width in range(1, 3): + for crop_height in range(1, 3): + for crop_width in range(2, 4): + for depth in range(1, 3): + for num_boxes in range(1, 3): - batch = num_boxes - image_shape = [batch, image_height, image_width, depth] - crop_size = [crop_height, crop_width] - crops_shape = [num_boxes, crop_height, crop_width, depth] - boxes_shape = [num_boxes, 4] + batch = num_boxes + image_shape = [batch, image_height, image_width, depth] + crop_size = [crop_height, crop_width] + crops_shape = [num_boxes, crop_height, crop_width, depth] + boxes_shape = [num_boxes, 4] - image = np.arange(0, batch * image_height * image_width * - depth).reshape(image_shape).astype(np.float32) - boxes = [] - for _ in range(num_boxes): - # pylint: disable=unbalanced-tuple-unpacking - y1, y2 = self._randomUniformAvoidAnchors( - low, high, np.linspace(0, 1, image_height), radius, 2) - x1, x2 = self._randomUniformAvoidAnchors( - low, high, np.linspace(0, 1, image_width), radius, 2) - # pylint: enable=unbalanced-tuple-unpacking - boxes.append([y1, x1, y2, x2]) + image = np.arange(0, batch * image_height * image_width * + depth).reshape(image_shape).astype(np.float32) + boxes = [] + for _ in range(num_boxes): + # pylint: disable=unbalanced-tuple-unpacking + y1, y2 = self._randomUniformAvoidAnchors( + low, high, np.linspace(0, 1, image_height), radius, 2) + x1, x2 = self._randomUniformAvoidAnchors( + low, high, np.linspace(0, 1, image_width), radius, 2) + # pylint: enable=unbalanced-tuple-unpacking + boxes.append([y1, x1, y2, x2]) - boxes = np.array(boxes, dtype=np.float32) - box_ind = np.arange(batch, dtype=np.int32) + boxes = np.array(boxes, dtype=np.float32) + box_ind = np.arange(batch, dtype=np.int32) - for use_gpu in [False, True]: - with self.test_session(use_gpu=use_gpu): - image_tensor = tf.constant(image, shape=image_shape) - boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4]) - box_ind_tensor = tf.constant(box_ind, shape=[num_boxes]) - crops = tf.image.crop_and_resize( - image_tensor, - boxes_tensor, - box_ind_tensor, - tf.constant(crop_size, shape=[2])) + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu): + image_tensor = tf.constant(image, shape=image_shape) + boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4]) + box_ind_tensor = tf.constant(box_ind, shape=[num_boxes]) + crops = tf.image.crop_and_resize( + image_tensor, + boxes_tensor, + box_ind_tensor, + tf.constant(crop_size, shape=[2])) - err = tf.test.compute_gradient_error( - [image_tensor, boxes_tensor], - [image_shape, boxes_shape], - crops, - crops_shape, - delta=delta, - x_init_value=[image, boxes]) + err = tf.test.compute_gradient_error( + [image_tensor, boxes_tensor], [image_shape, boxes_shape], + crops, + crops_shape, + delta=delta, + x_init_value=[image, boxes]) - self.assertLess(err, 2e-3) + self.assertLess(err, 2e-3) if __name__ == "__main__":