Correct bug in crop_and_resize which caused failures to some tests.
Change: 126246458
This commit is contained in:
parent
e8974bac93
commit
1d92cfcbf5
@ -42,6 +42,10 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
|
||||
const Tensor& boxes,
|
||||
const Tensor& box_ind,
|
||||
int* num_boxes) {
|
||||
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
|
||||
*num_boxes = 0;
|
||||
return;
|
||||
}
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
OP_REQUIRES(context, boxes.dims() == 2,
|
||||
errors::InvalidArgument("boxes must be 2-D",
|
||||
@ -132,9 +136,13 @@ class CropAndResizeOp : public OpKernel {
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
functor::CropAndResize<Device, T>()(context->eigen_device<Device>(),
|
||||
image_data, boxes_data, box_ind_data,
|
||||
extrapolation_value_, crops_data);
|
||||
bool status = functor::CropAndResize<Device, T>()(
|
||||
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
|
||||
extrapolation_value_, crops_data);
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeKernel."));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -145,11 +153,12 @@ class CropAndResizeOp : public OpKernel {
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct CropAndResize<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
float extrapolation_value,
|
||||
typename TTypes<float, 4>::Tensor crops) {
|
||||
const int batch = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -163,7 +172,11 @@ struct CropAndResize<CPUDevice, T> {
|
||||
const float x1 = boxes(b, 1);
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_scale =
|
||||
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
|
||||
@ -217,6 +230,7 @@ struct CropAndResize<CPUDevice, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
@ -235,6 +249,7 @@ class CropAndResizeGradImageOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
|
||||
const Tensor& grads = context->input(0);
|
||||
|
||||
OP_REQUIRES(context, grads.dims() == 4,
|
||||
errors::InvalidArgument("grads image must be 4-D",
|
||||
grads.shape().DebugString()));
|
||||
@ -294,9 +309,13 @@ class CropAndResizeGradImageOp : public OpKernel {
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
functor::CropAndResizeBackpropImage<Device, T>()(
|
||||
bool status = functor::CropAndResizeBackpropImage<Device, T>()(
|
||||
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
|
||||
output_data);
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -304,11 +323,12 @@ class CropAndResizeGradImageOp : public OpKernel {
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d,
|
||||
bool operator()(const CPUDevice& d,
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<T, 4>::Tensor grads_image) {
|
||||
const int batch = grads_image.dimension(0);
|
||||
const int image_height = grads_image.dimension(1);
|
||||
const int image_width = grads_image.dimension(2);
|
||||
|
||||
@ -324,7 +344,11 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
const float x1 = boxes(b, 1);
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_scale =
|
||||
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
|
||||
@ -370,6 +394,7 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
@ -388,6 +413,7 @@ class CropAndResizeGradBoxesOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
|
||||
const Tensor& grads = context->input(0);
|
||||
|
||||
OP_REQUIRES(context, grads.dims() == 4,
|
||||
errors::InvalidArgument("grads image must be 4-D",
|
||||
grads.shape().DebugString()));
|
||||
@ -441,9 +467,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
functor::CropAndResizeBackpropBoxes<Device, T>()(
|
||||
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
|
||||
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
|
||||
box_ind_data, output_data);
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -451,12 +481,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d,
|
||||
bool operator()(const CPUDevice& d,
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<float, 2>::Tensor grads_boxes) {
|
||||
const int batch = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -472,7 +503,11 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
const float x1 = boxes(b, 1);
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_ratio =
|
||||
(crop_height > 1)
|
||||
@ -547,6 +582,7 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
@ -563,37 +599,25 @@ inline void CheckValidBoxInd<CPUDevice>(
|
||||
}
|
||||
}
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("crop_size"), \
|
||||
CropAndResizeOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
|
||||
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("image_size"), \
|
||||
CropAndResizeGradImageOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_half(REGISTER_KERNEL);
|
||||
TF_CALL_float(REGISTER_KERNEL);
|
||||
TF_CALL_double(REGISTER_KERNEL);
|
||||
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("crop_size"), \
|
||||
CropAndResizeOp<CPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("image_size"), \
|
||||
CropAndResizeGradImageOp<CPUDevice, T>); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
CropAndResizeGradBoxesOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
|
||||
TF_CALL_float(REGISTER_KERNEL);
|
||||
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
@ -613,6 +637,10 @@ template <>
|
||||
inline void CheckValidBoxInd<GPUDevice>(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
int batch) {
|
||||
const int num_boxes = box_ind.dimension(0);
|
||||
if (num_boxes == 0) {
|
||||
return;
|
||||
}
|
||||
Tensor isvalid_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<bool>::value,
|
||||
@ -657,7 +685,7 @@ inline void CheckValidBoxInd<GPUDevice>(
|
||||
.TypeConstraint<T>("T"), \
|
||||
CropAndResizeGradBoxesOp<GPUDevice, T>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);
|
||||
TF_CALL_float(REGISTER_KERNEL);
|
||||
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
|
@ -26,7 +26,7 @@ namespace functor {
|
||||
template <typename Device, typename T>
|
||||
struct CropAndResize {
|
||||
// We assume that the tensor sizes are correct.
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
bool operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
float extrapolation_value,
|
||||
@ -36,7 +36,7 @@ struct CropAndResize {
|
||||
template <typename Device, typename T>
|
||||
struct CropAndResizeBackpropImage {
|
||||
// We assume that the tensor sizes are correct.
|
||||
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
|
||||
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<T, 4>::Tensor grads_image);
|
||||
@ -45,7 +45,7 @@ struct CropAndResizeBackpropImage {
|
||||
template <typename Device, typename T>
|
||||
struct CropAndResizeBackpropBoxes {
|
||||
// We assume that the tensor sizes are correct.
|
||||
void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
|
||||
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
|
@ -33,27 +33,30 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
__global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
|
||||
const float* boxes_ptr,
|
||||
const int32* box_ind_ptr, int num_boxes,
|
||||
int image_height, int image_width,
|
||||
int crop_height, int crop_width, int depth,
|
||||
float extrapolation_value,
|
||||
float* crops_ptr) {
|
||||
__global__ void CropAndResizeKernel(
|
||||
const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
|
||||
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
|
||||
int image_width, int crop_height, int crop_width, int depth,
|
||||
float extrapolation_value, float* crops_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int x = out_idx2 % crop_width;
|
||||
const int out_idx3 = out_idx2 / crop_width;
|
||||
const int y = out_idx3 % crop_height;
|
||||
const int b = out_idx3 / crop_height;
|
||||
int idx = out_idx;
|
||||
const int d = idx % depth;
|
||||
idx /= depth;
|
||||
const int x = idx % crop_width;
|
||||
idx /= crop_width;
|
||||
const int y = idx % crop_height;
|
||||
const int b = idx / crop_height;
|
||||
|
||||
const float y1 = boxes_ptr[b * 4];
|
||||
const float x1 = boxes_ptr[b * 4 + 1];
|
||||
const float y2 = boxes_ptr[b * 4 + 2];
|
||||
const float x2 = boxes_ptr[b * 4 + 3];
|
||||
|
||||
const int32 b_in = box_ind_ptr[b];
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_scale =
|
||||
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
|
||||
@ -66,7 +69,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
|
||||
: 0.5 * (y1 + y2) * (image_height - 1);
|
||||
if (in_y < 0 || in_y > image_height - 1) {
|
||||
crops_ptr[out_idx] = extrapolation_value;
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float in_x = (crop_width > 1)
|
||||
@ -74,7 +77,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
|
||||
: 0.5 * (x1 + x2) * (image_width - 1);
|
||||
if (in_x < 0 || in_x > image_width - 1) {
|
||||
crops_ptr[out_idx] = extrapolation_value;
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const int top_y_index = floorf(in_y);
|
||||
@ -114,22 +117,28 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
|
||||
template <typename T>
|
||||
__global__ void CropAndResizeBackpropImageKernel(
|
||||
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
|
||||
const int32* box_ind_ptr, int num_boxes, int image_height, int image_width,
|
||||
int crop_height, int crop_width, int depth, T* grads_image_ptr) {
|
||||
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
|
||||
int image_width, int crop_height, int crop_width, int depth,
|
||||
T* grads_image_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int x = out_idx2 % crop_width;
|
||||
const int out_idx3 = out_idx2 / crop_width;
|
||||
const int y = out_idx3 % crop_height;
|
||||
const int b = out_idx3 / crop_height;
|
||||
int idx = out_idx;
|
||||
const int d = idx % depth;
|
||||
idx /= depth;
|
||||
const int x = idx % crop_width;
|
||||
idx /= crop_width;
|
||||
const int y = idx % crop_height;
|
||||
const int b = idx / crop_height;
|
||||
|
||||
const float y1 = boxes_ptr[b * 4];
|
||||
const float x1 = boxes_ptr[b * 4 + 1];
|
||||
const float y2 = boxes_ptr[b * 4 + 2];
|
||||
const float x2 = boxes_ptr[b * 4 + 3];
|
||||
|
||||
const int32 b_in = box_ind_ptr[b];
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_scale =
|
||||
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
|
||||
@ -141,14 +150,14 @@ __global__ void CropAndResizeBackpropImageKernel(
|
||||
? y1 * (image_height - 1) + y * height_scale
|
||||
: 0.5 * (y1 + y2) * (image_height - 1);
|
||||
if (in_y < 0 || in_y > image_height - 1) {
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float in_x = (crop_width > 1)
|
||||
? x1 * (image_width - 1) + x * width_scale
|
||||
: 0.5 * (x1 + x2) * (image_width - 1);
|
||||
if (in_x < 0 || in_x > image_width - 1) {
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const int top_y_index = floorf(in_y);
|
||||
@ -192,23 +201,28 @@ __global__ void CropAndResizeBackpropImageKernel(
|
||||
template <typename T>
|
||||
__global__ void CropAndResizeBackpropBoxesKernel(
|
||||
const int32 nthreads, const float* grads_ptr, const T* image_ptr,
|
||||
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes,
|
||||
const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
|
||||
int image_height, int image_width, int crop_height, int crop_width,
|
||||
int depth, float* grads_boxes_ptr) {
|
||||
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
|
||||
const int d = out_idx % depth;
|
||||
const int out_idx2 = out_idx / depth;
|
||||
const int x = out_idx2 % crop_width;
|
||||
const int out_idx3 = out_idx2 / crop_width;
|
||||
const int y = out_idx3 % crop_height;
|
||||
const int b = out_idx3 / crop_height;
|
||||
int idx = out_idx;
|
||||
const int d = idx % depth;
|
||||
idx /= depth;
|
||||
const int x = idx % crop_width;
|
||||
idx /= crop_width;
|
||||
const int y = idx % crop_height;
|
||||
const int b = idx / crop_height;
|
||||
|
||||
const float y1 = boxes_ptr[b * 4];
|
||||
const float x1 = boxes_ptr[b * 4 + 1];
|
||||
const float y2 = boxes_ptr[b * 4 + 2];
|
||||
const float x2 = boxes_ptr[b * 4 + 3];
|
||||
|
||||
const int32 b_in = box_ind_ptr[b];
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float height_ratio =
|
||||
(crop_height > 1)
|
||||
@ -226,14 +240,14 @@ __global__ void CropAndResizeBackpropBoxesKernel(
|
||||
? y1 * (image_height - 1) + y * height_scale
|
||||
: 0.5 * (y1 + y2) * (image_height - 1);
|
||||
if (in_y < 0 || in_y > image_height - 1) {
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float in_x = (crop_width > 1)
|
||||
? x1 * (image_width - 1) + x * width_scale
|
||||
: 0.5 * (x1 + x2) * (image_width - 1);
|
||||
if (in_x < 0 || in_x > image_width - 1) {
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const int top_y_index = floorf(in_y);
|
||||
@ -306,11 +320,12 @@ namespace functor {
|
||||
|
||||
template <typename T>
|
||||
struct CropAndResize<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
bool operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
float extrapolation_value,
|
||||
typename TTypes<float, 4>::Tensor crops) {
|
||||
const int batch = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -320,19 +335,22 @@ struct CropAndResize<GPUDevice, T> {
|
||||
const int depth = crops.dimension(3);
|
||||
|
||||
const int total_count = num_boxes * crop_height * crop_width * depth;
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
|
||||
|
||||
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
|
||||
d.stream()>>>(
|
||||
config.virtual_thread_count, image.data(), boxes.data(), box_ind.data(),
|
||||
num_boxes, image_height, image_width, crop_height, crop_width, depth,
|
||||
extrapolation_value, crops.data());
|
||||
if (total_count > 0) {
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
|
||||
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
|
||||
d.stream()>>>(
|
||||
config.virtual_thread_count, image.data(), boxes.data(),
|
||||
box_ind.data(), num_boxes, batch, image_height, image_width,
|
||||
crop_height, crop_width, depth, extrapolation_value, crops.data());
|
||||
}
|
||||
return d.ok();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CropAndResizeBackpropImage<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d,
|
||||
bool operator()(const GPUDevice& d,
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
@ -351,29 +369,35 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
|
||||
|
||||
// Initialize grads_image with all zeros.
|
||||
total_count = batch * image_height * image_width * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
total_count, grads_image.data());
|
||||
if (total_count > 0) {
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads_image.data());
|
||||
}
|
||||
|
||||
// Accumulate.
|
||||
total_count = num_boxes * crop_height * crop_width * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
CropAndResizeBackpropImageKernel<<<
|
||||
config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads.data(), boxes.data(), box_ind.data(),
|
||||
num_boxes, image_height, image_width, crop_height, crop_width, depth,
|
||||
grads_image.data());
|
||||
if (total_count > 0) {
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
CropAndResizeBackpropImageKernel<<<
|
||||
config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads.data(), boxes.data(),
|
||||
box_ind.data(), num_boxes, batch, image_height, image_width,
|
||||
crop_height, crop_width, depth, grads_image.data());
|
||||
}
|
||||
return d.ok();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CropAndResizeBackpropBoxes<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d,
|
||||
bool operator()(const GPUDevice& d,
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<float, 2>::Tensor grads_boxes) {
|
||||
const int batch = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -387,18 +411,23 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
|
||||
|
||||
// Initialize grads_boxes with all zeros.
|
||||
total_count = num_boxes * 4;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
total_count, grads_boxes.data());
|
||||
if (total_count > 0) {
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads_boxes.data());
|
||||
}
|
||||
|
||||
// Accumulate.
|
||||
total_count = num_boxes * crop_height * crop_width * depth;
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
CropAndResizeBackpropBoxesKernel<<<
|
||||
config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
|
||||
box_ind.data(), num_boxes, image_height, image_width, crop_height,
|
||||
crop_width, depth, grads_boxes.data());
|
||||
if (total_count > 0) {
|
||||
config = GetCudaLaunchConfig(total_count, d);
|
||||
CropAndResizeBackpropBoxesKernel<<<
|
||||
config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
|
||||
box_ind.data(), num_boxes, batch, image_height, image_width,
|
||||
crop_height, crop_width, depth, grads_boxes.data());
|
||||
}
|
||||
return d.ok();
|
||||
}
|
||||
};
|
||||
|
||||
@ -407,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
|
||||
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
|
||||
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
|
||||
TF_CALL_float(DEFINE_GPU_SPECS);
|
||||
|
||||
#undef DEFINE_GPU_SPECS
|
||||
|
||||
|
@ -189,6 +189,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
|
||||
MakeOp(0);
|
||||
// Input:
|
||||
// 1, 2
|
||||
// 3, 4
|
||||
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
|
||||
AddInputFromArray<float>(TensorShape({0, 4}), {});
|
||||
AddInputFromArray<int32>(TensorShape({0}), {});
|
||||
AddInputFromArray<int32>(TensorShape({2}), {3, 3});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({0, 3, 3, 1}));
|
||||
// clang-format off
|
||||
test::FillValues<float>(&expected, {});
|
||||
// clang-format on
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
|
||||
MakeOp(0);
|
||||
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
|
||||
@ -201,6 +219,19 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
|
||||
MakeOp(0);
|
||||
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
|
||||
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
|
||||
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({2}), {4, 4});
|
||||
Status s = RunOpKernel();
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
|
||||
MakeOp(0);
|
||||
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
|
||||
|
@ -253,54 +253,53 @@ class CropAndResizeOpTest(tf.test.TestCase):
|
||||
radius = 2 * delta
|
||||
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
|
||||
|
||||
for image_height in range(1, 5):
|
||||
for image_width in range(1, 3):
|
||||
for crop_height in range(1, 3):
|
||||
for crop_width in range(2, 4):
|
||||
for depth in range(1, 3):
|
||||
for num_boxes in range(1, 3):
|
||||
image_height = 4
|
||||
for image_width in range(1, 3):
|
||||
for crop_height in range(1, 3):
|
||||
for crop_width in range(2, 4):
|
||||
for depth in range(1, 3):
|
||||
for num_boxes in range(1, 3):
|
||||
|
||||
batch = num_boxes
|
||||
image_shape = [batch, image_height, image_width, depth]
|
||||
crop_size = [crop_height, crop_width]
|
||||
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
||||
boxes_shape = [num_boxes, 4]
|
||||
batch = num_boxes
|
||||
image_shape = [batch, image_height, image_width, depth]
|
||||
crop_size = [crop_height, crop_width]
|
||||
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
||||
boxes_shape = [num_boxes, 4]
|
||||
|
||||
image = np.arange(0, batch * image_height * image_width *
|
||||
depth).reshape(image_shape).astype(np.float32)
|
||||
boxes = []
|
||||
for _ in range(num_boxes):
|
||||
# pylint: disable=unbalanced-tuple-unpacking
|
||||
y1, y2 = self._randomUniformAvoidAnchors(
|
||||
low, high, np.linspace(0, 1, image_height), radius, 2)
|
||||
x1, x2 = self._randomUniformAvoidAnchors(
|
||||
low, high, np.linspace(0, 1, image_width), radius, 2)
|
||||
# pylint: enable=unbalanced-tuple-unpacking
|
||||
boxes.append([y1, x1, y2, x2])
|
||||
image = np.arange(0, batch * image_height * image_width *
|
||||
depth).reshape(image_shape).astype(np.float32)
|
||||
boxes = []
|
||||
for _ in range(num_boxes):
|
||||
# pylint: disable=unbalanced-tuple-unpacking
|
||||
y1, y2 = self._randomUniformAvoidAnchors(
|
||||
low, high, np.linspace(0, 1, image_height), radius, 2)
|
||||
x1, x2 = self._randomUniformAvoidAnchors(
|
||||
low, high, np.linspace(0, 1, image_width), radius, 2)
|
||||
# pylint: enable=unbalanced-tuple-unpacking
|
||||
boxes.append([y1, x1, y2, x2])
|
||||
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
box_ind = np.arange(batch, dtype=np.int32)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
box_ind = np.arange(batch, dtype=np.int32)
|
||||
|
||||
for use_gpu in [False, True]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
image_tensor = tf.constant(image, shape=image_shape)
|
||||
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
|
||||
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
|
||||
crops = tf.image.crop_and_resize(
|
||||
image_tensor,
|
||||
boxes_tensor,
|
||||
box_ind_tensor,
|
||||
tf.constant(crop_size, shape=[2]))
|
||||
for use_gpu in [False, True]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
image_tensor = tf.constant(image, shape=image_shape)
|
||||
boxes_tensor = tf.constant(boxes, shape=[num_boxes, 4])
|
||||
box_ind_tensor = tf.constant(box_ind, shape=[num_boxes])
|
||||
crops = tf.image.crop_and_resize(
|
||||
image_tensor,
|
||||
boxes_tensor,
|
||||
box_ind_tensor,
|
||||
tf.constant(crop_size, shape=[2]))
|
||||
|
||||
err = tf.test.compute_gradient_error(
|
||||
[image_tensor, boxes_tensor],
|
||||
[image_shape, boxes_shape],
|
||||
crops,
|
||||
crops_shape,
|
||||
delta=delta,
|
||||
x_init_value=[image, boxes])
|
||||
err = tf.test.compute_gradient_error(
|
||||
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
|
||||
crops,
|
||||
crops_shape,
|
||||
delta=delta,
|
||||
x_init_value=[image, boxes])
|
||||
|
||||
self.assertLess(err, 2e-3)
|
||||
self.assertLess(err, 2e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user