Fix a bug in crop_and_resize that caused some tests to fail.

Change: 126246458
This commit is contained in:
A. Unique TensorFlower 2016-06-29 15:32:57 -08:00 committed by TensorFlower Gardener
parent e8974bac93
commit 1d92cfcbf5
5 changed files with 232 additions and 145 deletions

View File

@ -42,6 +42,10 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
const Tensor& boxes,
const Tensor& box_ind,
int* num_boxes) {
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
*num_boxes = 0;
return;
}
// The shape of 'boxes' is [num_boxes, 4].
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
@ -132,9 +136,13 @@ class CropAndResizeOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResize<Device, T>()(context->eigen_device<Device>(),
image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
bool status = functor::CropAndResize<Device, T>()(
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
}
private:
@ -145,11 +153,12 @@ class CropAndResizeOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResize<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -163,7 +172,11 @@ struct CropAndResize<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -217,6 +230,7 @@ struct CropAndResize<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -235,6 +249,7 @@ class CropAndResizeGradImageOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@ -294,9 +309,13 @@ class CropAndResizeGradImageOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResizeBackpropImage<Device, T>()(
bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
}
}
};
@ -304,11 +323,12 @@ class CropAndResizeGradImageOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropImage<CPUDevice, T> {
void operator()(const CPUDevice& d,
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image) {
const int batch = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@ -324,7 +344,11 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -370,6 +394,7 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -388,6 +413,7 @@ class CropAndResizeGradBoxesOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@ -441,9 +467,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
functor::CropAndResizeBackpropBoxes<Device, T>()(
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
box_ind_data, output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
}
}
};
@ -451,12 +481,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
void operator()(const CPUDevice& d,
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -472,7 +503,11 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b);
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_ratio =
(crop_height > 1)
@ -547,6 +582,7 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
}
}
}
return true;
}
};
} // namespace functor
@ -563,37 +599,25 @@ inline void CheckValidBoxInd<CPUDevice>(
}
}
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>);
TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("image_size"), \
CropAndResizeGradImageOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@ -613,6 +637,10 @@ template <>
inline void CheckValidBoxInd<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch) {
const int num_boxes = box_ind.dimension(0);
if (num_boxes == 0) {
return;
}
Tensor isvalid_tensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<bool>::value,
@ -657,7 +685,7 @@ inline void CheckValidBoxInd<GPUDevice>(
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL

View File

@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename T>
struct CropAndResize {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
@ -36,7 +36,7 @@ struct CropAndResize {
template <typename Device, typename T>
struct CropAndResizeBackpropImage {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image);
@ -45,7 +45,7 @@ struct CropAndResizeBackpropImage {
template <typename Device, typename T>
struct CropAndResizeBackpropBoxes {
// We assume that the tensor sizes are correct.
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,

View File

@ -33,27 +33,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename T>
__global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes,
int image_height, int image_width,
int crop_height, int crop_width, int depth,
float extrapolation_value,
float* crops_ptr) {
__global__ void CropAndResizeKernel(
const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
float extrapolation_value, float* crops_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -66,7 +69,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
crops_ptr[out_idx] = extrapolation_value;
return;
continue;
}
const float in_x = (crop_width > 1)
@ -74,7 +77,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
crops_ptr[out_idx] = extrapolation_value;
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -114,22 +117,28 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
template <typename T>
__global__ void CropAndResizeBackpropImageKernel(
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int image_height, int image_width,
int crop_height, int crop_width, int depth, T* grads_image_ptr) {
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
T* grads_image_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@ -141,14 +150,14 @@ __global__ void CropAndResizeBackpropImageKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
return;
continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -192,23 +201,28 @@ __global__ void CropAndResizeBackpropImageKernel(
template <typename T>
__global__ void CropAndResizeBackpropBoxesKernel(
const int32 nthreads, const float* grads_ptr, const T* image_ptr,
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes,
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
int image_height, int image_width, int crop_height, int crop_width,
int depth, float* grads_boxes_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
const int d = out_idx % depth;
const int out_idx2 = out_idx / depth;
const int x = out_idx2 % crop_width;
const int out_idx3 = out_idx2 / crop_width;
const int y = out_idx3 % crop_height;
const int b = out_idx3 / crop_height;
int idx = out_idx;
const int d = idx % depth;
idx /= depth;
const int x = idx % crop_width;
idx /= crop_width;
const int y = idx % crop_height;
const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
const int32 b_in = box_ind_ptr[b];
if (b_in < 0 || b_in >= batch) {
continue;
}
const float height_ratio =
(crop_height > 1)
@ -226,14 +240,14 @@ __global__ void CropAndResizeBackpropBoxesKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
return;
continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
return;
continue;
}
const int top_y_index = floorf(in_y);
@ -306,11 +320,12 @@ namespace functor {
template <typename T>
struct CropAndResize<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
bool operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -320,19 +335,22 @@ struct CropAndResize<GPUDevice, T> {
const int depth = crops.dimension(3);
const int total_count = num_boxes * crop_height * crop_width * depth;
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(
config.virtual_thread_count, image.data(), boxes.data(), box_ind.data(),
num_boxes, image_height, image_width, crop_height, crop_width, depth,
extrapolation_value, crops.data());
if (total_count > 0) {
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(
config.virtual_thread_count, image.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, extrapolation_value, crops.data());
}
return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropImage<GPUDevice, T> {
void operator()(const GPUDevice& d,
bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
@ -351,29 +369,35 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
// Initialize grads_image with all zeros.
total_count = batch * image_height * image_width * depth;
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
total_count, grads_image.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads_image.data());
}
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropImageKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), boxes.data(), box_ind.data(),
num_boxes, image_height, image_width, crop_height, crop_width, depth,
grads_image.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropImageKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, grads_image.data());
}
return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropBoxes<GPUDevice, T> {
void operator()(const GPUDevice& d,
bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@ -387,18 +411,23 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
// Initialize grads_boxes with all zeros.
total_count = num_boxes * 4;
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
total_count, grads_boxes.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads_boxes.data());
}
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropBoxesKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
box_ind.data(), num_boxes, image_height, image_width, crop_height,
crop_width, depth, grads_boxes.data());
if (total_count > 0) {
config = GetCudaLaunchConfig(total_count, d);
CropAndResizeBackpropBoxesKernel<<<
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
crop_height, crop_width, depth, grads_boxes.data());
}
return d.ok();
}
};
@ -407,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
TF_CALL_float(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS

View File

@ -189,6 +189,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
// Runs the op on a 1x2x2x1 image with empty box inputs (shapes [0, 4] and
// [0]) and checks that the kernel succeeds and yields an empty output of
// shape [0, 3, 3, 1].
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
// NOTE(review): the meaning of the MakeOp argument is not visible here;
// presumably the extrapolation value -- confirm against the fixture.
MakeOp(0);
// Input:
// 1, 2
// 3, 4
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
// Second input (presumably 'boxes'): zero rows of 4 coordinates.
AddInputFromArray<float>(TensorShape({0, 4}), {});
// Third input (presumably 'box_ind'): empty, matching the zero boxes.
AddInputFromArray<int32>(TensorShape({0}), {});
// Fourth input (presumably 'crop_size'): 3 x 3, matching the expected
// output's middle dimensions below.
AddInputFromArray<int32>(TensorShape({2}), {3, 3});
// The kernel must run without error even though there is nothing to crop.
TF_ASSERT_OK(RunOpKernel());
// Expected result has a leading dimension of 0, so it holds no values.
Tensor expected(allocator(), DT_FLOAT, TensorShape({0, 3, 3, 1}));
// clang-format off
test::FillValues<float>(&expected, {});
// clang-format on
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
@ -201,6 +219,19 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
<< s;
}
// Feeds one box (shape [1, 4]) together with a two-element index tensor
// (shape [2]) and checks that the kernel fails with a status mentioning
// "box_ind has incompatible shape".
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
// NOTE(review): the meaning of the MakeOp argument is not visible here;
// presumably the extrapolation value -- confirm against the fixture.
MakeOp(0);
// 1x2x2x1 image holding the values 1..4.
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
// A single box with coordinates {0, 0, 1, 1}.
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
// Two entries here versus one box above: the deliberate shape mismatch.
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
// Fourth input (presumably 'crop_size'): 4 x 4.
AddInputFromArray<int32>(TensorShape({2}), {4, 4});
Status s = RunOpKernel();
// The op must reject the inputs rather than run.
ASSERT_FALSE(s.ok());
// On failure, stream the full status so the actual message is visible.
EXPECT_TRUE(
StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
<< s;
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});

View File

@ -253,54 +253,53 @@ class CropAndResizeOpTest(tf.test.TestCase):
radius = 2 * delta
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
for image_height in range(1, 5):
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
image_height = 4
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
image_tensor = tf.constant(image, shape=image_shape)
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
crops = tf.image.crop_and_resize(
image_tensor,
boxes_tensor,
box_ind_tensor,
tf.constant(crop_size, shape=[2]))
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
image_tensor = tf.constant(image, shape=image_shape)
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
crops = tf.image.crop_and_resize(
image_tensor,
boxes_tensor,
box_ind_tensor,
tf.constant(crop_size, shape=[2]))
err = tf.test.compute_gradient_error(
[image_tensor, boxes_tensor],
[image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
err = tf.test.compute_gradient_error(
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
self.assertLess(err, 2e-3)
self.assertLess(err, 2e-3)
if __name__ == "__main__":